from __future__ import annotations import json import argparse import os from pathlib import Path from typing import Any, Dict, List from conv_data_gen.generators.conversation.core_simulator import ( ConversationSimulatorWithTools, ) from conv_data_gen.generators.user.user_persona_generator import ( UserPersonaGenerator, PersonaSample, ) from concurrent.futures import ThreadPoolExecutor, as_completed def _project_root() -> Path: return Path(__file__).resolve().parent.parent def _default_assets_dir() -> Path: return _project_root() / "tests" / "assets" def _default_output_dir() -> Path: out = _project_root() / "tests" / "output" / "user_randomizer" out.mkdir(parents=True, exist_ok=True) return out def _load_proxy_json(path: Path) -> Dict[str, str]: obj = json.loads(path.read_text(encoding="utf-8")) if isinstance(obj, dict): return {str(k): str(v) for k, v in obj.items()} raise ValueError("Proxy JSON must be an object/dict") def _read_text_file(path: Path) -> str: return path.read_text(encoding="utf-8") def run_user_randomizer_test( n_variants: int, bot_prompt_path: Path, proxy_json_path: Path, out_dir: Path, ) -> Path: out_dir.mkdir(parents=True, exist_ok=True) proxy = _load_proxy_json(proxy_json_path) bot_text = _read_text_file(bot_prompt_path) # Single generator reused for sampling + formatting gen = UserPersonaGenerator( company_name="sarvam-finance", stakeholder="customer", use_case="loan", bot_prompt_text=bot_text, ) mapping: List[Dict[str, Any]] = [] def _work(i: int) -> Dict[str, Any]: # Sample attributes sample = gen._randomizer.sample_persona() # Compose persona text using reusable generator proxy_for_format = {str(k): str(v) for k, v in proxy.items()} persona = PersonaSample( knobs=sample.get("knobs", {}), knob_descriptions=sample.get("knob_descriptions", {}), language_style=sample.get("language_style", {}), language_descriptions=sample.get("language_descriptions", {}), demographics=sample.get("demographics", {}), demographic_descriptions=sample.get( "demographic_descriptions", {} ), interaction=sample.get("interaction", {}), ) user_text = gen._compose_user_message_for_proxy( persona, proxy_for_format ) # Persist persona # Run simulation sim = ConversationSimulatorWithTools( user_prompt_path=None, user_prompt_text=user_text, bot_prompt_path=str(bot_prompt_path), max_turns=10, output_dir=str(out_dir), base_filename=f"persona_{i}", conversation_direction="agent_to_user", ) transcript = sim.generate() conv_base = (out_dir / f"persona_{i}").as_posix() # Prepend sampled knobs to the conversation text file try: conv_txt_path = out_dir / f"persona_{i}.txt" if conv_txt_path.exists(): txt = conv_txt_path.read_text(encoding="utf-8") knobs_lines: List[str] = ["USER KNOBS (sampled):"] for k, v in sample.get("knobs", {}).items(): knobs_lines.append(f"- {k}: {v}") header = "\n".join(knobs_lines) + "\n\n" conv_txt_path.write_text(header + txt, encoding="utf-8") except Exception: pass # Enrich conversation JSON with user attributes conv_json_path = out_dir / f"persona_{i}.json" try: if conv_json_path.exists(): meta = json.loads(conv_json_path.read_text(encoding="utf-8")) meta["user_attributes"] = { "knobs": sample.get("knobs", {}), "language_style": sample.get("language_style", {}), "demographics": sample.get("demographics", {}), "interaction": sample.get("interaction", {}), } conv_json_path.write_text( json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8", ) except Exception: pass return { "persona_index": i, "conversation_base": conv_base, "user_prompt_path": "", "bot_prompt_path": str(bot_prompt_path), "sampled_attributes": { "knobs": sample.get("knobs", {}), "language_style": sample.get("language_style", {}), "demographics": sample.get("demographics", {}), "interaction": sample.get("interaction", {}), }, "total_turns": len(transcript), } cpu = os.cpu_count() or 2 max_workers = max(1, min(16, cpu * 4)) workers = min(max_workers, max(1, n_variants)) futures = [] with ThreadPoolExecutor(max_workers=workers) as ex: for i in range(1, n_variants + 1): futures.append(ex.submit(_work, i)) for fut in as_completed(futures): try: res = fut.result() if isinstance(res, dict): mapping.append(res) except Exception: # Skip failed runs but continue others pass # Stable order by persona index mapping.sort(key=lambda r: int(r.get("persona_index", 0))) # Write compact JSON map: conversation file base -> user attributes map_path = out_dir / "conversation_user_attributes_map.json" with open(map_path, "w", encoding="utf-8") as f: json.dump(mapping, f, ensure_ascii=False, indent=2) return map_path def _parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description=( "Generate conversations to test the effect of user attributes." ) ) p.add_argument("--num", type=int, default=10, help="Number of variants") p.add_argument( "--bot", type=Path, default=_default_assets_dir() / "bot.txt", help="Path to bot prompt text file", ) p.add_argument( "--proxy", type=Path, default=_default_assets_dir() / "proxy.json", help="Path to base user proxy JSON file", ) p.add_argument( "--out", type=Path, default=_default_output_dir(), help="Output directory for personas, conversations, and map", ) return p.parse_args() if __name__ == "__main__": args = _parse_args() path = run_user_randomizer_test( n_variants=max(1, int(args.num)), bot_prompt_path=args.bot, proxy_json_path=args.proxy, out_dir=args.out, ) print(str(path))