# Grade persona knobs from conversation transcripts and compare against the
# actual attributes recorded at generation time; emits a paired CSV report.
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from conv_data_gen.llm import LLMClient | |
| from conv_data_gen.config import config | |
| def _default_input_dir() -> Path: | |
| return Path(__file__).resolve().parent / "output" / "user_randomizer" | |
| def _default_output_csv(input_dir: Path) -> Path: | |
| return input_dir / "graded_knobs_report.csv" | |
| def _list_conversation_jsons(input_dir: Path) -> List[Path]: | |
| return sorted(input_dir.glob("persona_*.json")) | |
| def _load_transcript_and_actual(path: Path) -> Tuple[str, Dict[str, Any]]: | |
| obj = json.loads(path.read_text(encoding="utf-8")) | |
| turns = obj.get("turns") or [] | |
| actual = obj.get("user_attributes") or {} | |
| lines: List[str] = [] | |
| for t in turns: | |
| try: | |
| role = str(t.get("role") or "") | |
| content = str(t.get("content") or "") | |
| if role and content: | |
| lines.append(f"{role}: {content}") | |
| except Exception: | |
| continue | |
| transcript = "\n".join(lines) | |
| return transcript, actual | |
| def _read_system_prompt(path: Path) -> str: | |
| try: | |
| return path.read_text(encoding="utf-8").strip() | |
| except Exception: | |
| return "" | |
| def _default_system_path() -> Path: | |
| return Path(__file__).resolve().parent / "assets" / "grader_system.txt" | |
def _build_grading_prompt(
    transcript: str,
    knob_keys: List[str],
    lang_keys: List[str],
    system_prompt_path: Path,
) -> Tuple[str, List[Dict[str, str]]]:
    """Compose the (system prompt, chat messages) pair for the grading LLM.

    A JSON skeleton listing the expected knob / language-style keys is
    embedded in the user message to steer structured output.
    """
    fallback = (
        "You are an evaluator. Respond ONLY as JSON with keys 'knobs' and "
        "'language_style'."
    )
    system = _read_system_prompt(system_prompt_path) or fallback

    def _entries(keys: List[str]) -> str:
        # One `"key": ""` slot per field, comma-joined; empty when no keys.
        return ",\n ".join(f'"{k}": ""' for k in keys) if keys else ""

    skeleton = (
        '{\n"knobs": {\n '
        + _entries(knob_keys)
        + "\n},\n"
        + '"language_style": {\n '
        + _entries(lang_keys)
        + "\n}\n}"
    )
    fields_desc = (
        "Infer these fields (fill values in this JSON skeleton exactly):\n\n"
        + skeleton
    )
    user = fields_desc + "\n\nTRANSCRIPT (role: text):\n\n" + transcript
    return system, [{"role": "user", "content": user}]
def _grade_one(
    json_path: Path,
    llm: LLMClient,
    system_prompt_path: Path,
) -> Dict[str, Any]:
    """Grade a single conversation file from its transcript alone.

    Returns a dict holding the source path, the "actual" attributes recorded
    in the file, and the attributes the grading model inferred (possibly {}).
    """
    transcript, actual = _load_transcript_and_actual(json_path)

    # Grade exactly the fields named in the recorded metadata (names only).
    knob_keys = sorted((actual.get("knobs") or {}).keys())
    lang_keys = sorted((actual.get("language_style") or {}).keys())

    # Fallback: derive knob names from persona_knobs.yaml when metadata is bare.
    if not knob_keys:
        try:
            import yaml  # type: ignore

            loaded = yaml.safe_load(
                config.paths.PERSONA_KNOBS_YAML.read_text(encoding="utf-8")
            )
            if isinstance(loaded, dict) and isinstance(loaded.get("knobs"), dict):
                knob_keys = sorted(str(k) for k in loaded["knobs"])
        except Exception:
            pass  # best-effort: leave the (empty) key list alone

    if not lang_keys:
        lang_keys = [
            "language",
            "directness",
            "formality",
            "code_switch_ratio",
            "regionalisms",
        ]

    system, messages = _build_grading_prompt(
        transcript, knob_keys, lang_keys, system_prompt_path
    )
    resp = llm.get_llm_response_json(
        messages=messages,
        model=config.models.USER_MODEL,
        system_prompt=system,
        temperature=config.models.USER_MODEL_TEMPERATURE,
    )

    inferred: Dict[str, Any] = {}
    text = (resp.get("text") or "").strip()
    if text:
        # Tolerant parse via the client helper first ...
        parsed = llm.safe_parse_json(text)
        if isinstance(parsed, dict):
            inferred = parsed
        else:
            # ... otherwise grab the outermost {...} span and parse just that.
            lo = text.find("{")
            hi = text.rfind("}")
            if lo != -1 and hi > lo:
                try:
                    inferred = json.loads(text[lo : hi + 1])
                except Exception:
                    inferred = {}

    return {"path": str(json_path), "actual": actual, "inferred": inferred}
| def _flatten_for_csv( | |
| item: Dict[str, Any], | |
| knobs_order: List[str], | |
| lang_order: List[str], | |
| ) -> Dict[str, str]: | |
| path = Path(item.get("path") or "") | |
| persona_idx = path.stem.split("_")[-1] if path.stem else "" | |
| actual = item.get("actual") or {} | |
| inferred = item.get("inferred") or {} | |
| out: Dict[str, str] = { | |
| "persona_index": persona_idx, | |
| "conversation_base": str(path.with_suffix("").as_posix()), | |
| } | |
| # Pair columns per key: actual then graded | |
| a_knobs = actual.get("knobs") or {} | |
| i_knobs = inferred.get("knobs") or {} | |
| for k in knobs_order: | |
| out[f"actual.knobs.{k}"] = str(a_knobs.get(k, "")) | |
| out[f"graded.knobs.{k}"] = str(i_knobs.get(k, "")) | |
| a_lang = actual.get("language_style") or {} | |
| i_lang = inferred.get("language_style") or {} | |
| for k in lang_order: | |
| out[f"actual.language_style.{k}"] = str(a_lang.get(k, "")) | |
| out[f"graded.language_style.{k}"] = str(i_lang.get(k, "")) | |
| return out | |
def run_grader(
    input_dir: Path,
    out_csv: Path,
    system_prompt_path: Path,
) -> Path:
    """Grade every persona_*.json in *input_dir* and write a paired CSV report.

    Each row places the actual and graded value side by side for every knob /
    language-style key observed in any conversation.

    Args:
        input_dir: Directory containing persona_*.json conversation files.
        out_csv: Destination CSV path (parent directories are created).
        system_prompt_path: File holding the grader system prompt.

    Raises:
        FileNotFoundError: if *input_dir* contains no persona_*.json files.

    Returns:
        The path of the written CSV (same as *out_csv*).
    """
    input_dir.mkdir(parents=True, exist_ok=True)
    json_paths = _list_conversation_jsons(input_dir)
    if not json_paths:
        raise FileNotFoundError(f"No persona_*.json found in {input_dir}")

    llm = LLMClient()
    rows: List[Dict[str, Any]] = []
    # Moderate concurrency: grading is I/O-bound on the LLM endpoint.
    cpu = os.cpu_count() or 2
    max_workers = max(1, min(16, cpu * 2))
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [
            ex.submit(_grade_one, p, llm, system_prompt_path)
            for p in json_paths
        ]
        for fut in as_completed(futures):
            try:
                res = fut.result()
                if isinstance(res, dict):
                    rows.append(res)
            except Exception:
                # Best-effort: a failed grading call drops that row only.
                continue

    # Union of keys across all rows (sorted within a row, first-seen order
    # overall) so every actual/graded pair gets a stable column.
    knobs_union: List[str] = []
    lang_union: List[str] = []
    seen_knobs: set = set()
    seen_lang: set = set()
    for r in rows:
        a = r.get("actual") or {}
        i = r.get("inferred") or {}
        for k in sorted(set(a.get("knobs") or {}) | set(i.get("knobs") or {})):
            if k not in seen_knobs:
                seen_knobs.add(k)
                knobs_union.append(k)
        for k in sorted(
            set(a.get("language_style") or {})
            | set(i.get("language_style") or {})
        ):
            if k not in seen_lang:
                seen_lang.add(k)
                lang_union.append(k)

    flat_rows: List[Dict[str, str]] = [
        _flatten_for_csv(r, knobs_union, lang_union) for r in rows
    ]

    # Deterministic header order: identity columns, then actual/graded pairs.
    headers: List[str] = ["persona_index", "conversation_base"]
    for k in knobs_union:
        headers.extend((f"actual.knobs.{k}", f"graded.knobs.{k}"))
    for k in lang_union:
        headers.extend(
            (f"actual.language_style.{k}", f"graded.language_style.{k}")
        )

    def _index_key(row: Dict[str, str]) -> Tuple[int, str]:
        # Sort numerically when the persona suffix is an int; non-numeric
        # suffixes (glob persona_*.json does not guarantee digits) sort after
        # the numeric ones lexicographically instead of crashing int().
        raw = row.get("persona_index", "") or ""
        try:
            return (int(raw), "")
        except ValueError:
            return (1 << 62, raw)

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=headers)
        writer.writeheader()
        for row in sorted(flat_rows, key=_index_key):
            writer.writerow(row)
    return out_csv
def _parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI options for the standalone grader."""
    parser = argparse.ArgumentParser(
        description=(
            "Grade user knobs from transcripts only and compare to actual."
        )
    )
    # (flag, default, help) — every option is a Path-typed value.
    options = [
        (
            "--input",
            _default_input_dir(),
            "Input directory containing persona_*.json",
        ),
        (
            "--out",
            None,
            "Output CSV path (defaults to graded_knobs_report.csv)",
        ),
        (
            "--system",
            _default_system_path(),
            "System prompt file for grading",
        ),
    ]
    for flag, default, help_text in options:
        parser.add_argument(flag, type=Path, default=default, help=help_text)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse CLI options, grade, print the report path.
    cli = _parse_args()
    report_csv = cli.out or _default_output_csv(cli.input)
    result = run_grader(cli.input, report_csv, cli.system)
    print(str(result))