# data-gen/tests/test_user_randomizer_grader.py
# (uploaded via huggingface_hub; revision fc1a684)
from __future__ import annotations
import argparse
import csv
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from conv_data_gen.llm import LLMClient
from conv_data_gen.config import config
def _default_input_dir() -> Path:
return Path(__file__).resolve().parent / "output" / "user_randomizer"
def _default_output_csv(input_dir: Path) -> Path:
return input_dir / "graded_knobs_report.csv"
def _list_conversation_jsons(input_dir: Path) -> List[Path]:
return sorted(input_dir.glob("persona_*.json"))
def _load_transcript_and_actual(path: Path) -> Tuple[str, Dict[str, Any]]:
obj = json.loads(path.read_text(encoding="utf-8"))
turns = obj.get("turns") or []
actual = obj.get("user_attributes") or {}
lines: List[str] = []
for t in turns:
try:
role = str(t.get("role") or "")
content = str(t.get("content") or "")
if role and content:
lines.append(f"{role}: {content}")
except Exception:
continue
transcript = "\n".join(lines)
return transcript, actual
def _read_system_prompt(path: Path) -> str:
try:
return path.read_text(encoding="utf-8").strip()
except Exception:
return ""
def _default_system_path() -> Path:
return Path(__file__).resolve().parent / "assets" / "grader_system.txt"
def _build_grading_prompt(
    transcript: str,
    knob_keys: List[str],
    lang_keys: List[str],
    system_prompt_path: Path,
) -> Tuple[str, List[Dict[str, str]]]:
    """Assemble the (system prompt, messages) pair for one grading request.

    The user message embeds a JSON skeleton listing only the field *names*
    the model should fill, steering it toward structured output.
    """
    fallback_system = (
        "You are an evaluator. Respond ONLY as JSON with keys 'knobs' and "
        "'language_style'."
    )
    system = _read_system_prompt(system_prompt_path) or fallback_system

    def _entries(keys: List[str]) -> str:
        # Render each key as an empty-valued JSON field.
        return ",\n ".join(f'"{k}": ""' for k in keys) if keys else ""

    skeleton = (
        '{\n"knobs": {\n '
        + _entries(knob_keys)
        + "\n},\n"
        + '"language_style": {\n '
        + _entries(lang_keys)
        + "\n}\n}"
    )
    user = (
        "Infer these fields (fill values in this JSON skeleton exactly):\n\n"
        + skeleton
        + "\n\nTRANSCRIPT (role: text):\n\n"
        + transcript
    )
    return system, [{"role": "user", "content": user}]
def _grade_one(
    json_path: Path,
    llm: LLMClient,
    system_prompt_path: Path,
) -> Dict[str, Any]:
    """Grade a single conversation file against its stored attributes.

    Builds a transcript-only grading prompt (the model never sees the actual
    values, only the field names), asks the LLM to infer knob values, and
    returns a record pairing actual and inferred attributes.

    Args:
        json_path: Conversation file produced by the randomizer.
        llm: Shared LLM client (assumed safe for concurrent use by the
            thread pool in run_grader — TODO confirm).
        system_prompt_path: File containing the grader system prompt.

    Returns:
        Dict with keys "path" (str), "actual" (stored attributes) and
        "inferred" (LLM output; {} when parsing fails).
    """
    transcript, actual = _load_transcript_and_actual(json_path)
    # Derive the fields to grade from the stored metadata (names only).
    knob_keys = sorted((actual.get("knobs") or {}).keys())
    lang_keys = sorted((actual.get("language_style") or {}).keys())
    # Fallback: when no knob names are stored, take them from persona_knobs.yaml.
    if not knob_keys:
        try:
            import yaml  # type: ignore

            loaded = yaml.safe_load(
                config.paths.PERSONA_KNOBS_YAML.read_text(encoding="utf-8")
            )
            if isinstance(loaded, dict):
                knobs_map = loaded.get("knobs") or {}
                if isinstance(knobs_map, dict):
                    knob_keys = sorted(str(k) for k in knobs_map)
        except Exception:
            # Best-effort only: keep the (empty) key list on any failure.
            # (Was the no-op `knob_keys = knob_keys`.)
            pass
    if not lang_keys:
        lang_keys = [
            "language",
            "directness",
            "formality",
            "code_switch_ratio",
            "regionalisms",
        ]
    system, messages = _build_grading_prompt(
        transcript, knob_keys, lang_keys, system_prompt_path
    )
    resp = llm.get_llm_response_json(
        messages=messages,
        model=config.models.USER_MODEL,
        system_prompt=system,
        temperature=config.models.USER_MODEL_TEMPERATURE,
    )
    text = (resp.get("text") or "").strip()
    inferred: Dict[str, Any] = {}
    if text:
        # Tolerant parse via the client helper first.
        parsed = llm.safe_parse_json(text)
        if isinstance(parsed, dict):
            inferred = parsed
        else:
            # Fall back to extracting the outermost {...} span; keep the
            # try-body minimal so only the json.loads failure is swallowed.
            start = text.find("{")
            end = text.rfind("}")
            if start != -1 and end > start:
                try:
                    inferred = json.loads(text[start : end + 1])
                except Exception:
                    inferred = {}
    return {
        "path": str(json_path),
        "actual": actual,
        "inferred": inferred,
    }
def _flatten_for_csv(
item: Dict[str, Any],
knobs_order: List[str],
lang_order: List[str],
) -> Dict[str, str]:
path = Path(item.get("path") or "")
persona_idx = path.stem.split("_")[-1] if path.stem else ""
actual = item.get("actual") or {}
inferred = item.get("inferred") or {}
out: Dict[str, str] = {
"persona_index": persona_idx,
"conversation_base": str(path.with_suffix("").as_posix()),
}
# Pair columns per key: actual then graded
a_knobs = actual.get("knobs") or {}
i_knobs = inferred.get("knobs") or {}
for k in knobs_order:
out[f"actual.knobs.{k}"] = str(a_knobs.get(k, ""))
out[f"graded.knobs.{k}"] = str(i_knobs.get(k, ""))
a_lang = actual.get("language_style") or {}
i_lang = inferred.get("language_style") or {}
for k in lang_order:
out[f"actual.language_style.{k}"] = str(a_lang.get(k, ""))
out[f"graded.language_style.{k}"] = str(i_lang.get(k, ""))
return out
def _union_keys(rows: List[Dict[str, Any]], section: str) -> List[str]:
    """Ordered union of keys under *section* across actual and inferred dicts."""
    seen: set = set()
    ordered: List[str] = []
    for r in rows:
        a_keys = ((r.get("actual") or {}).get(section) or {}).keys()
        i_keys = ((r.get("inferred") or {}).get(section) or {}).keys()
        for k in sorted(set(a_keys) | set(i_keys)):
            if k not in seen:
                seen.add(k)
                ordered.append(k)
    return ordered


def _persona_sort_key(row: Dict[str, str]) -> int:
    """Numeric sort key for CSV rows; non-numeric indexes sort first as 0."""
    try:
        return int(row.get("persona_index", "0") or 0)
    except ValueError:
        # Was an unhandled ValueError mid-write when a filename did not end
        # in a number; degrade gracefully instead.
        return 0


def run_grader(
    input_dir: Path,
    out_csv: Path,
    system_prompt_path: Path,
) -> Path:
    """Grade every persona_*.json in *input_dir* and write a CSV report.

    Conversations are graded concurrently; individual failures are skipped
    so one bad file does not abort the run.

    Args:
        input_dir: Directory containing persona_*.json conversation files.
        out_csv: Destination CSV path (parent directories are created).
        system_prompt_path: Grader system prompt file.

    Returns:
        The path of the written CSV.

    Raises:
        FileNotFoundError: when no persona_*.json files are found.
    """
    input_dir.mkdir(parents=True, exist_ok=True)
    json_paths = _list_conversation_jsons(input_dir)
    if not json_paths:
        raise FileNotFoundError(f"No persona_*.json found in {input_dir}")
    llm = LLMClient()
    rows: List[Dict[str, Any]] = []
    # Moderate concurrency: I/O-bound LLM calls, capped at 16 workers.
    cpu = os.cpu_count() or 2
    max_workers = max(1, min(16, cpu * 2))
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [
            ex.submit(_grade_one, p, llm, system_prompt_path)
            for p in json_paths
        ]
        for fut in as_completed(futures):
            try:
                res = fut.result()
            except Exception:
                # One failed grading call should not abort the whole run.
                continue
            if isinstance(res, dict):
                rows.append(res)
    # Union key orders drive the paired actual/graded columns.
    knobs_union = _union_keys(rows, "knobs")
    lang_union = _union_keys(rows, "language_style")
    flat_rows = [_flatten_for_csv(r, knobs_union, lang_union) for r in rows]
    # Deterministic header order with actual/graded pairs per key.
    headers: List[str] = ["persona_index", "conversation_base"]
    for k in knobs_union:
        headers += [f"actual.knobs.{k}", f"graded.knobs.{k}"]
    for k in lang_union:
        headers += [f"actual.language_style.{k}", f"graded.language_style.{k}"]
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=headers)
        writer.writeheader()
        for r in sorted(flat_rows, key=_persona_sort_key):
            writer.writerow(r)
    return out_csv
def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the grader script."""
    description = "Grade user knobs from transcripts only and compare to actual."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--input",
        type=Path,
        default=_default_input_dir(),
        help="Input directory containing persona_*.json",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output CSV path (defaults to graded_knobs_report.csv)",
    )
    parser.add_argument(
        "--system",
        type=Path,
        default=_default_system_path(),
        help="System prompt file for grading",
    )
    return parser.parse_args()
if __name__ == "__main__":
    # CLI entry point: grade all conversations and print the report path.
    args = _parse_args()
    # --out is optional; default the report next to the input conversations.
    out_csv = args.out or _default_output_csv(args.input)
    path = run_grader(args.input, out_csv, args.system)
    print(str(path))