"""Grade persona knobs from conversation transcripts and compare to actual.

For every ``persona_*.json`` conversation in an input directory, an LLM is
asked to infer the user's ``knobs`` and ``language_style`` from the
transcript only. The inferred values are written to a CSV next to the values
recorded in the file's ``user_attributes``.
"""

from __future__ import annotations

import argparse
import csv
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Dict, List, Tuple

from conv_data_gen.llm import LLMClient
from conv_data_gen.config import config

logger = logging.getLogger(__name__)

# language_style fields graded when a conversation records none of its own.
_DEFAULT_LANG_KEYS: Tuple[str, ...] = (
    "language",
    "directness",
    "formality",
    "code_switch_ratio",
    "regionalisms",
)


def _default_input_dir() -> Path:
    """Return the default directory holding ``persona_*.json`` files."""
    return Path(__file__).resolve().parent / "output" / "user_randomizer"


def _default_output_csv(input_dir: Path) -> Path:
    """Return the default report path inside *input_dir*."""
    return input_dir / "graded_knobs_report.csv"


def _list_conversation_jsons(input_dir: Path) -> List[Path]:
    """Return the ``persona_*.json`` files in *input_dir*, sorted by path."""
    return sorted(input_dir.glob("persona_*.json"))


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict.

    Neither the conversation files nor the LLM output are schema-checked,
    so a section such as ``"knobs"`` may be null, a string or a list.
    """
    return value if isinstance(value, dict) else {}


def _persona_index(path: Path) -> str:
    """Return the text after the last ``_`` in *path*'s stem ("" if none)."""
    return path.stem.split("_")[-1] if path.stem else ""


def _load_transcript_and_actual(path: Path) -> Tuple[str, Dict[str, Any]]:
    """Load one conversation file and render its turns as a transcript.

    Returns:
        ``(transcript, actual)``: *transcript* has one ``role: content``
        line per turn where both fields are non-empty; *actual* is the
        file's ``user_attributes`` object (``{}`` when absent or not a dict).

    Raises:
        OSError, ValueError: if the file cannot be read or is not valid JSON.
    """
    obj = json.loads(path.read_text(encoding="utf-8"))
    turns = obj.get("turns") or []
    actual = _as_dict(obj.get("user_attributes"))
    lines: List[str] = []
    for turn in turns:
        if not isinstance(turn, dict):
            continue  # malformed turn entries are skipped, not fatal
        role = str(turn.get("role") or "")
        content = str(turn.get("content") or "")
        if role and content:
            lines.append(f"{role}: {content}")
    return "\n".join(lines), actual


def _read_system_prompt(path: Path) -> str:
    """Return the stripped contents of *path*, or "" if it can't be read."""
    try:
        return path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        return ""


def _default_system_path() -> Path:
    """Return the bundled grader system-prompt file path."""
    return Path(__file__).resolve().parent / "assets" / "grader_system.txt"


def _build_grading_prompt(
    transcript: str,
    knob_keys: List[str],
    lang_keys: List[str],
    system_prompt_path: Path,
) -> Tuple[str, List[Dict[str, str]]]:
    """Build the system prompt and chat messages for one grading request.

    Returns:
        ``(system, messages)``. *system* is read from *system_prompt_path*,
        falling back to a short built-in instruction when that file is
        missing or empty. *messages* holds a single user message with a JSON
        skeleton of the requested fields followed by the transcript.
    """
    system = _read_system_prompt(system_prompt_path) or (
        "You are an evaluator. Respond ONLY as JSON with keys 'knobs' and "
        "'language_style'."
    )
    # A JSON skeleton with empty values steers the model toward structured
    # output containing exactly the requested fields.
    knob_entries = ",\n ".join(f'"{k}": ""' for k in knob_keys)
    lang_entries = ",\n ".join(f'"{k}": ""' for k in lang_keys)
    skeleton = (
        '{\n"knobs": {\n '
        + knob_entries
        + "\n},\n"
        + '"language_style": {\n '
        + lang_entries
        + "\n}\n}"
    )
    fields_desc = (
        "Infer these fields (fill values in this JSON skeleton exactly):\n\n"
        + skeleton
    )
    user = fields_desc + "\n\nTRANSCRIPT (role: text):\n\n" + transcript
    messages = [{"role": "user", "content": user}]
    return system, messages


def _load_default_knob_keys() -> List[str]:
    """Return sorted knob names from ``persona_knobs.yaml``, or ``[]``."""
    try:
        import yaml  # type: ignore

        ypath = config.paths.PERSONA_KNOBS_YAML
        loaded = yaml.safe_load(ypath.read_text(encoding="utf-8"))
    except Exception:
        # Deliberate best-effort: missing PyYAML, config entry or file just
        # means there are no default knob names to grade.
        return []
    knobs_map = _as_dict(loaded).get("knobs")
    if isinstance(knobs_map, dict):
        return sorted(str(k) for k in knobs_map)
    return []


def _parse_inferred(text: str, llm: LLMClient) -> Dict[str, Any]:
    """Parse the model's reply into a dict, returning ``{}`` on failure.

    Tries the client's tolerant parser first, then falls back to decoding
    the span between the first ``{`` and the last ``}``.
    """
    if not text:
        return {}
    parsed = llm.safe_parse_json(text)
    if isinstance(parsed, dict):
        return parsed
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return {}
    try:
        candidate = json.loads(text[start : end + 1])
    except (ValueError, RecursionError):
        return {}
    return _as_dict(candidate)


def _grade_one(
    json_path: Path,
    llm: LLMClient,
    system_prompt_path: Path,
) -> Dict[str, Any]:
    """Ask the LLM to infer one conversation's attributes from its transcript.

    The fields to grade are the keys present in the file's actual
    ``knobs`` / ``language_style``. When knobs are absent, names come from
    ``persona_knobs.yaml``; when language_style is absent,
    ``_DEFAULT_LANG_KEYS`` is used.

    Returns:
        ``{"path": str, "actual": dict, "inferred": dict}``.
    """
    transcript, actual = _load_transcript_and_actual(json_path)
    # Only key names are sent to the model, never the actual values.
    knob_keys = sorted(_as_dict(actual.get("knobs")))
    lang_keys = sorted(_as_dict(actual.get("language_style")))
    if not knob_keys:
        knob_keys = _load_default_knob_keys()
    if not lang_keys:
        lang_keys = list(_DEFAULT_LANG_KEYS)

    system, messages = _build_grading_prompt(
        transcript, knob_keys, lang_keys, system_prompt_path
    )
    resp = llm.get_llm_response_json(
        messages=messages,
        model=config.models.USER_MODEL,
        system_prompt=system,
        temperature=config.models.USER_MODEL_TEMPERATURE,
    )
    text = (resp.get("text") or "").strip()
    return {
        "path": str(json_path),
        "actual": actual,
        "inferred": _parse_inferred(text, llm),
    }


def _flatten_for_csv(
    item: Dict[str, Any],
    knobs_order: List[str],
    lang_order: List[str],
) -> Dict[str, str]:
    """Flatten one graded item into a CSV row.

    Columns are ``persona_index``, ``conversation_base``, then for each key
    an ``actual.<section>.<key>`` / ``graded.<section>.<key>`` pair (knobs
    first, then language_style). Missing values become "".
    """
    path = Path(item.get("path") or "")
    actual = _as_dict(item.get("actual"))
    inferred = _as_dict(item.get("inferred"))
    out: Dict[str, str] = {
        "persona_index": _persona_index(path),
        "conversation_base": path.with_suffix("").as_posix(),
    }
    for section, order in (
        ("knobs", knobs_order),
        ("language_style", lang_order),
    ):
        a_vals = _as_dict(actual.get(section))
        i_vals = _as_dict(inferred.get(section))
        for k in order:
            out[f"actual.{section}.{k}"] = str(a_vals.get(k, ""))
            out[f"graded.{section}.{k}"] = str(i_vals.get(k, ""))
    return out


def _row_sort_key(item: Dict[str, Any]) -> Tuple[int, int, str]:
    """Order graded items numerically by persona index, then by path.

    Items whose index is not an integer (e.g. ``persona_abc.json``) sort
    after all numeric ones instead of aborting the report.
    """
    path = str(item.get("path") or "")
    try:
        return (0, int(_persona_index(Path(path))), path)
    except ValueError:
        return (1, 0, path)


def _union_keys(rows: List[Dict[str, Any]], section: str) -> List[str]:
    """Return *section* keys across actual and inferred, in first-seen order.

    Each row contributes its keys sorted; rows are visited in the given
    order, so a deterministic row order gives a deterministic column order.
    """
    ordered: List[str] = []
    seen: set = set()
    for row in rows:
        a_keys = _as_dict(_as_dict(row.get("actual")).get(section))
        i_keys = _as_dict(_as_dict(row.get("inferred")).get(section))
        for key in sorted(set(a_keys) | set(i_keys)):
            if key not in seen:
                seen.add(key)
                ordered.append(key)
    return ordered


def _csv_headers(knobs_order: List[str], lang_order: List[str]) -> List[str]:
    """Return the CSV header matching ``_flatten_for_csv``'s columns."""
    headers: List[str] = ["persona_index", "conversation_base"]
    for section, order in (
        ("knobs", knobs_order),
        ("language_style", lang_order),
    ):
        for k in order:
            headers.append(f"actual.{section}.{k}")
            headers.append(f"graded.{section}.{k}")
    return headers


def _grade_all(
    json_paths: List[Path],
    llm: LLMClient,
    system_prompt_path: Path,
) -> List[Dict[str, Any]]:
    """Grade all files concurrently; failed files are logged and skipped."""
    # Moderate concurrency for the I/O-bound LLM calls.
    cpu = os.cpu_count() or 2
    max_workers = max(1, min(16, cpu * 2))
    rows: List[Dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = {
            ex.submit(_grade_one, p, llm, system_prompt_path): p
            for p in json_paths
        }
        for fut in as_completed(futures):
            try:
                res = fut.result()
            except Exception as exc:
                # One bad file must not sink the whole report, but the
                # failure should not vanish silently either.
                logger.warning("Failed to grade %s: %r", futures[fut], exc)
                continue
            if isinstance(res, dict):
                rows.append(res)
    return rows


def run_grader(
    input_dir: Path,
    out_csv: Path,
    system_prompt_path: Path,
) -> Path:
    """Grade every ``persona_*.json`` in *input_dir* and write a CSV report.

    Args:
        input_dir: Directory containing the conversation JSON files.
        out_csv: Destination CSV path; parent directories are created.
        system_prompt_path: File holding the grader's system prompt.

    Returns:
        *out_csv*.

    Raises:
        FileNotFoundError: if *input_dir* has no ``persona_*.json`` files.
    """
    input_dir.mkdir(parents=True, exist_ok=True)
    json_paths = _list_conversation_jsons(input_dir)
    if not json_paths:
        raise FileNotFoundError(f"No persona_*.json found in {input_dir}")

    llm = LLMClient()
    rows = _grade_all(json_paths, llm, system_prompt_path)
    # Sort before building key unions: as_completed yields in thread-timing
    # order, which would otherwise make the column order non-deterministic.
    rows.sort(key=_row_sort_key)

    knobs_union = _union_keys(rows, "knobs")
    lang_union = _union_keys(rows, "language_style")
    headers = _csv_headers(knobs_union, lang_union)

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=headers)
        writer.writeheader()
        writer.writerows(
            _flatten_for_csv(r, knobs_union, lang_union) for r in rows
        )
    return out_csv


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the grader script."""
    p = argparse.ArgumentParser(
        description=(
            "Grade user knobs from transcripts only and compare to actual."
        )
    )
    p.add_argument(
        "--input",
        type=Path,
        default=_default_input_dir(),
        help="Input directory containing persona_*.json",
    )
    p.add_argument(
        "--out",
        type=Path,
        default=None,
        help="Output CSV path (defaults to graded_knobs_report.csv)",
    )
    p.add_argument(
        "--system",
        type=Path,
        default=_default_system_path(),
        help="System prompt file for grading",
    )
    return p.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    out_csv = args.out or _default_output_csv(args.input)
    path = run_grader(args.input, out_csv, args.system)
    print(str(path))