Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import json | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| from conv_data_gen.generators.conversation.core_simulator import ( | |
| ConversationSimulatorWithTools, | |
| ) | |
| from conv_data_gen.generators.user.user_persona_generator import ( | |
| UserPersonaGenerator, | |
| PersonaSample, | |
| ) | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
def _project_root() -> Path:
    """Return the directory two levels above this file's resolved location."""
    module_dir = Path(__file__).resolve().parent
    return module_dir.parent
def _default_assets_dir() -> Path:
    """Return ``<project root>/tests/assets`` (the path is not created or checked)."""
    return _project_root().joinpath("tests", "assets")
def _default_output_dir() -> Path:
    """Return ``<project root>/tests/output/user_randomizer``, creating it if needed.

    Calling this function has a filesystem side effect: the directory and
    any missing parents are created.
    """
    target = _project_root().joinpath("tests", "output", "user_randomizer")
    target.mkdir(parents=True, exist_ok=True)
    return target
def _load_proxy_json(path: Path) -> Dict[str, str]:
    """Load a JSON object from *path* and stringify its keys and values.

    Args:
        path: UTF-8 encoded JSON file whose top-level value is an object.

    Returns:
        A new dict in which every key and value has been passed through ``str``.

    Raises:
        ValueError: if the top-level JSON value is not an object (invalid
            JSON also surfaces as a ``ValueError`` subclass from ``json``).
    """
    parsed = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(parsed, dict):
        raise ValueError("Proxy JSON must be an object/dict")
    stringified: Dict[str, str] = {}
    for key, value in parsed.items():
        stringified[str(key)] = str(value)
    return stringified
def _read_text_file(path: Path) -> str:
    """Return the full contents of *path*, decoded as UTF-8."""
    with path.open("r", encoding="utf-8") as handle:
        return handle.read()
def _sampled_attributes(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the four attribute groups of a sampled persona that get persisted."""
    return {
        "knobs": sample.get("knobs", {}),
        "language_style": sample.get("language_style", {}),
        "demographics": sample.get("demographics", {}),
        "interaction": sample.get("interaction", {}),
    }


def _prepend_knobs_header(conv_txt_path: Path, knobs: Dict[str, Any]) -> None:
    """Prefix an existing transcript text file with the sampled knob values."""
    if not conv_txt_path.exists():
        return
    transcript_text = conv_txt_path.read_text(encoding="utf-8")
    header_lines: List[str] = ["USER KNOBS (sampled):"]
    header_lines.extend(f"- {k}: {v}" for k, v in knobs.items())
    header = "\n".join(header_lines) + "\n\n"
    conv_txt_path.write_text(header + transcript_text, encoding="utf-8")


def _attach_user_attributes(
    conv_json_path: Path, attributes: Dict[str, Any]
) -> None:
    """Add a ``user_attributes`` entry to an existing conversation JSON file."""
    if not conv_json_path.exists():
        return
    meta = json.loads(conv_json_path.read_text(encoding="utf-8"))
    meta["user_attributes"] = dict(attributes)
    conv_json_path.write_text(
        json.dumps(meta, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def run_user_randomizer_test(
    n_variants: int,
    bot_prompt_path: Path,
    proxy_json_path: Path,
    out_dir: Path,
) -> Path:
    """Simulate one conversation per randomly sampled persona and map the results.

    For each index ``1..n_variants`` a persona is sampled, turned into a user
    prompt, and fed to ``ConversationSimulatorWithTools`` with base filename
    ``persona_<i>``. If ``persona_<i>.txt`` / ``persona_<i>.json`` exist in
    *out_dir* afterwards (presumably written by the simulator; confirm
    against its implementation), they are annotated with the sampled
    attributes on a best-effort basis. Runs execute concurrently in a thread
    pool.

    Args:
        n_variants: Number of personas/conversations to generate.
        bot_prompt_path: Text file containing the bot prompt.
        proxy_json_path: JSON object file with the base user proxy.
        out_dir: Directory for all outputs; created if missing.

    Returns:
        Path to ``conversation_user_attributes_map.json`` inside *out_dir*,
        a JSON list with one entry per successful run, ordered by persona
        index.

    Raises:
        ValueError: if the proxy JSON is not an object.
        OSError: if the input files cannot be read or the map cannot be
            written.

    Failures of individual persona runs and of the post-processing steps are
    logged and skipped rather than raised.
    """
    # Local import so this block does not depend on the file's import header.
    import logging

    log = logging.getLogger(__name__)

    out_dir.mkdir(parents=True, exist_ok=True)
    proxy = _load_proxy_json(proxy_json_path)
    bot_text = _read_text_file(bot_prompt_path)

    # Single generator reused for sampling + formatting.
    # NOTE(review): it is shared across worker threads; confirm that
    # UserPersonaGenerator tolerates concurrent use.
    gen = UserPersonaGenerator(
        company_name="sarvam-finance",
        stakeholder="customer",
        use_case="loan",
        bot_prompt_text=bot_text,
    )

    def _work(i: int) -> Dict[str, Any]:
        sample = gen._randomizer.sample_persona()
        persona = PersonaSample(
            knobs=sample.get("knobs", {}),
            knob_descriptions=sample.get("knob_descriptions", {}),
            language_style=sample.get("language_style", {}),
            language_descriptions=sample.get("language_descriptions", {}),
            demographics=sample.get("demographics", {}),
            demographic_descriptions=sample.get(
                "demographic_descriptions", {}
            ),
            interaction=sample.get("interaction", {}),
        )
        # Each run gets its own copy of the proxy so that runs stay
        # independent even if the formatter modifies its argument.
        user_text = gen._compose_user_message_for_proxy(persona, dict(proxy))

        sim = ConversationSimulatorWithTools(
            user_prompt_path=None,
            user_prompt_text=user_text,
            bot_prompt_path=str(bot_prompt_path),
            max_turns=10,
            output_dir=str(out_dir),
            base_filename=f"persona_{i}",
            conversation_direction="agent_to_user",
        )
        transcript = sim.generate()

        attributes = _sampled_attributes(sample)

        # Post-processing is best-effort: a failure here must not discard
        # the conversation itself, but it should be visible in the logs.
        try:
            _prepend_knobs_header(
                out_dir / f"persona_{i}.txt", attributes["knobs"]
            )
        except Exception:
            log.warning(
                "Could not prepend knobs to persona_%d.txt", i, exc_info=True
            )
        try:
            _attach_user_attributes(out_dir / f"persona_{i}.json", attributes)
        except Exception:
            log.warning(
                "Could not enrich persona_%d.json", i, exc_info=True
            )

        return {
            "persona_index": i,
            "conversation_base": (out_dir / f"persona_{i}").as_posix(),
            "user_prompt_path": "",
            "bot_prompt_path": str(bot_prompt_path),
            "sampled_attributes": attributes,
            "total_turns": len(transcript),
        }

    cpu = os.cpu_count() or 2
    max_workers = max(1, min(16, cpu * 4))
    workers = min(max_workers, max(1, n_variants))

    mapping: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=workers) as ex:
        pending = {ex.submit(_work, i): i for i in range(1, n_variants + 1)}
        for fut in as_completed(pending):
            try:
                mapping.append(fut.result())
            except Exception:
                # Skip the failed run but keep the others, and record why.
                log.exception(
                    "Persona run %d failed; continuing with remaining runs",
                    pending[fut],
                )

    # Completion order is arbitrary; restore a stable order by persona index.
    mapping.sort(key=lambda r: int(r.get("persona_index", 0)))

    map_path = out_dir / "conversation_user_attributes_map.json"
    with open(map_path, "w", encoding="utf-8") as f:
        json.dump(mapping, f, ensure_ascii=False, indent=2)
    return map_path
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options:
        --num:   number of variants to generate (int, default 10).
        --bot:   bot prompt text file (default ``<assets>/bot.txt``).
        --proxy: base user proxy JSON file (default ``<assets>/proxy.json``).
        --out:   output directory (default: the directory returned by
                 ``_default_output_dir()``).

    Returns:
        The parsed namespace; ``out`` is always set to a ``Path``.
    """
    assets = _default_assets_dir()
    parser = argparse.ArgumentParser(
        description=(
            "Generate conversations to test the effect of user attributes."
        )
    )
    parser.add_argument(
        "--num", type=int, default=10, help="Number of variants"
    )
    parser.add_argument(
        "--bot",
        type=Path,
        default=assets / "bot.txt",
        help="Path to bot prompt text file",
    )
    parser.add_argument(
        "--proxy",
        type=Path,
        default=assets / "proxy.json",
        help="Path to base user proxy JSON file",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output directory for personas, conversations, and map",
    )
    parsed = parser.parse_args()
    if parsed.out is None:
        # _default_output_dir() creates the directory, so call it only when
        # the default is actually used -- not when --out is given or when
        # the parser exits early (e.g. --help).
        parsed.out = _default_output_dir()
    return parsed
if __name__ == "__main__":
    # Script entry point: run the generator and print the path of the
    # resulting attributes map.
    cli = _parse_args()
    requested = int(cli.num)
    map_file = run_user_randomizer_test(
        # At least one variant is always generated.
        n_variants=requested if requested > 1 else 1,
        bot_prompt_path=cli.bot,
        proxy_json_path=cli.proxy,
        out_dir=cli.out,
    )
    print(str(map_file))