# data-gen/tests/test_user_randomizer.py
# ashish-sarvam's picture
# Upload folder using huggingface_hub
# fc1a684 verified
from __future__ import annotations
import json
import argparse
import os
from pathlib import Path
from typing import Any, Dict, List
from conv_data_gen.generators.conversation.core_simulator import (
ConversationSimulatorWithTools,
)
from conv_data_gen.generators.user.user_persona_generator import (
UserPersonaGenerator,
PersonaSample,
)
from concurrent.futures import ThreadPoolExecutor, as_completed
def _project_root() -> Path:
    """Return the absolute directory two levels above this file."""
    containing_dir = Path(__file__).resolve().parent
    return containing_dir.parent
def _default_assets_dir() -> Path:
    """Return ``<project root>/tests/assets`` (the path is not created)."""
    return _project_root().joinpath("tests", "assets")
def _default_output_dir() -> Path:
    """Return ``<project root>/tests/output/user_randomizer``.

    The directory, including any missing parents, is created as a side
    effect of calling this function.
    """
    target = _project_root().joinpath("tests", "output", "user_randomizer")
    target.mkdir(parents=True, exist_ok=True)
    return target
def _load_proxy_json(path: Path) -> Dict[str, str]:
    """Load a UTF-8 JSON object from *path* as a ``str -> str`` mapping.

    Every key and value is passed through ``str()``, so non-string values
    (numbers, nested containers, ...) end up as their Python string form.

    Raises:
        ValueError: If the top-level JSON value is not an object.
    """
    parsed = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(parsed, dict):
        raise ValueError("Proxy JSON must be an object/dict")

    stringified: Dict[str, str] = {}
    for key, value in parsed.items():
        stringified[str(key)] = str(value)
    return stringified
def _read_text_file(path: Path) -> str:
    """Return the full contents of *path*, decoded as UTF-8."""
    with path.open(mode="r", encoding="utf-8") as handle:
        return handle.read()
def _sampled_attributes(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the attribute groups of *sample* that are persisted to disk.

    Groups missing from *sample* default to an empty dict.
    """
    groups = ("knobs", "language_style", "demographics", "interaction")
    return {group: sample.get(group, {}) for group in groups}


def _prepend_knobs_header(conv_txt_path: Path, knobs: Dict[str, Any]) -> None:
    """Prepend a ``USER KNOBS (sampled):`` header to a transcript file.

    Does nothing when *conv_txt_path* does not exist.
    """
    if not conv_txt_path.exists():
        return
    original = conv_txt_path.read_text(encoding="utf-8")
    lines: List[str] = ["USER KNOBS (sampled):"]
    lines.extend(f"- {name}: {value}" for name, value in knobs.items())
    header = "\n".join(lines) + "\n\n"
    conv_txt_path.write_text(header + original, encoding="utf-8")


def _attach_user_attributes(
    conv_json_path: Path, attributes: Dict[str, Any]
) -> None:
    """Store *attributes* under ``user_attributes`` in a conversation JSON.

    Does nothing when *conv_json_path* does not exist.
    """
    if not conv_json_path.exists():
        return
    meta = json.loads(conv_json_path.read_text(encoding="utf-8"))
    meta["user_attributes"] = attributes
    conv_json_path.write_text(
        json.dumps(meta, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


def run_user_randomizer_test(
    n_variants: int,
    bot_prompt_path: Path,
    proxy_json_path: Path,
    out_dir: Path,
) -> Path:
    """Simulate one conversation per randomly sampled user persona.

    For each index ``1..n_variants`` a persona is sampled, rendered into a
    user prompt together with the proxy data, and run through
    ``ConversationSimulatorWithTools`` against the bot prompt. Runs execute
    concurrently in a thread pool. Afterwards ``persona_<i>.txt`` and
    ``persona_<i>.json`` in *out_dir* (if present) are annotated with the
    sampled attributes, and a summary of all successful runs is written to
    ``conversation_user_attributes_map.json``.

    Args:
        n_variants: Number of personas/conversations to generate. Values
            below 1 produce an empty map.
        bot_prompt_path: Text file containing the bot prompt.
        proxy_json_path: JSON object file with the base user proxy data.
        out_dir: Directory for all outputs; created if missing.

    Returns:
        Path of the written ``conversation_user_attributes_map.json``.

    Raises:
        ValueError: If the proxy JSON is not an object.
        OSError: If the prompt/proxy files cannot be read or the map cannot
            be written.

    A persona whose run raises is logged and left out of the map; a failure
    while annotating an output file is logged and otherwise ignored.
    """
    # Local import keeps this block self-contained (no file-level changes).
    import logging

    log = logging.getLogger(__name__)

    out_dir.mkdir(parents=True, exist_ok=True)
    proxy = _load_proxy_json(proxy_json_path)
    bot_text = _read_text_file(bot_prompt_path)

    # Single generator reused for sampling + formatting.
    gen = UserPersonaGenerator(
        company_name="sarvam-finance",
        stakeholder="customer",
        use_case="loan",
        bot_prompt_text=bot_text,
    )

    def _work(i: int) -> Dict[str, Any]:
        """Sample persona *i*, simulate its conversation, annotate outputs."""
        # NOTE(review): relies on private members of UserPersonaGenerator
        # (_randomizer, _compose_user_message_for_proxy) - confirm there is
        # no public equivalent.
        sample = gen._randomizer.sample_persona()
        attributes = _sampled_attributes(sample)

        persona = PersonaSample(
            knobs=attributes["knobs"],
            knob_descriptions=sample.get("knob_descriptions", {}),
            language_style=attributes["language_style"],
            language_descriptions=sample.get("language_descriptions", {}),
            demographics=attributes["demographics"],
            demographic_descriptions=sample.get(
                "demographic_descriptions", {}
            ),
            interaction=attributes["interaction"],
        )
        # Each worker formats against its own copy of the proxy mapping.
        user_text = gen._compose_user_message_for_proxy(persona, dict(proxy))

        stem = f"persona_{i}"
        sim = ConversationSimulatorWithTools(
            user_prompt_path=None,
            user_prompt_text=user_text,
            bot_prompt_path=str(bot_prompt_path),
            max_turns=10,
            output_dir=str(out_dir),
            base_filename=stem,
            conversation_direction="agent_to_user",
        )
        transcript = sim.generate()

        # Post-processing is best-effort: a failure here must not discard
        # the conversation, but it should be visible in the logs.
        # Presumably the simulator wrote <stem>.txt / <stem>.json into
        # out_dir; both helpers are no-ops when the file is absent.
        conv_txt_path = out_dir / f"{stem}.txt"
        try:
            _prepend_knobs_header(conv_txt_path, attributes["knobs"])
        except Exception:
            log.warning(
                "Could not prepend knobs header to %s",
                conv_txt_path,
                exc_info=True,
            )

        conv_json_path = out_dir / f"{stem}.json"
        try:
            _attach_user_attributes(conv_json_path, attributes)
        except Exception:
            log.warning(
                "Could not add user attributes to %s",
                conv_json_path,
                exc_info=True,
            )

        return {
            "persona_index": i,
            "conversation_base": (out_dir / stem).as_posix(),
            "user_prompt_path": "",
            "bot_prompt_path": str(bot_prompt_path),
            "sampled_attributes": attributes,
            "total_turns": len(transcript),
        }

    # At most 16 threads, and never more threads than personas.
    cpu = os.cpu_count() or 2
    workers = max(1, min(16, cpu * 4, n_variants))

    mapping: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=workers) as ex:
        futures = {
            ex.submit(_work, i): i for i in range(1, n_variants + 1)
        }
        for fut in as_completed(futures):
            try:
                mapping.append(fut.result())
            except Exception:
                # Skip failed runs but continue with the others.
                log.exception(
                    "Persona %d failed; omitting it from the map",
                    futures[fut],
                )

    # Completion order is arbitrary; keep the map stable by persona index.
    mapping.sort(key=lambda entry: entry["persona_index"])

    # Write compact JSON map: conversation file base -> user attributes.
    map_path = out_dir / "conversation_user_attributes_map.json"
    with open(map_path, "w", encoding="utf-8") as f:
        json.dump(mapping, f, ensure_ascii=False, indent=2)
    return map_path
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the user-randomizer script.

    Options:
        --num:   number of variants (int, default 10).
        --bot:   bot prompt text file (default ``bot.txt`` in the default
                 assets directory).
        --proxy: base user proxy JSON file (default ``proxy.json`` in the
                 default assets directory).
        --out:   output directory (default: the default output directory).

    Returns:
        Namespace with ``num``, ``bot``, ``proxy`` and ``out``; ``out`` is
        always a ``Path``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate conversations to test the effect of user attributes."
        )
    )
    parser.add_argument(
        "--num", type=int, default=10, help="Number of variants"
    )

    assets_dir = _default_assets_dir()
    parser.add_argument(
        "--bot",
        type=Path,
        default=assets_dir / "bot.txt",
        help="Path to bot prompt text file",
    )
    parser.add_argument(
        "--proxy",
        type=Path,
        default=assets_dir / "proxy.json",
        help="Path to base user proxy JSON file",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output directory for personas, conversations, and map",
    )

    args = parser.parse_args()
    if args.out is None:
        # Resolved only when --out was omitted: _default_output_dir()
        # creates the directory on disk, which must not happen when the
        # caller chose a different output location.
        args.out = _default_output_dir()
    return args
def _main() -> None:
    """Run the randomizer from the command line and print the map path."""
    cli = _parse_args()
    # Always generate at least one variant, whatever --num says.
    variant_count = max(1, int(cli.num))
    map_file = run_user_randomizer_test(
        n_variants=variant_count,
        bot_prompt_path=cli.bot,
        proxy_json_path=cli.proxy,
        out_dir=cli.out,
    )
    print(str(map_file))


if __name__ == "__main__":
    _main()