Spaces:
Running
Running
| """Diagnose vanilla-model failures: truncation vs genuine schema violation. | |
| Runs N examples through an Ollama Cloud model, categorizing each output: | |
| empty / no_json / truncated / no_close_brace / malformed_json / schema_invalid / valid | |
| and reading `oll`'s stderr token counts to detect output hitting the cap. | |
| uv run python -m eval.diagnose_model --n 12 --model glm-5.1 --max-tokens 8000 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import re | |
| import subprocess | |
| from collections import Counter | |
| from jsonschema import Draft202012Validator | |
| from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt | |
| from scrubdata.profiler import profile_dataframe | |
| from training.generate import make_example | |
| from .metrics import PLAN_SCHEMA | |
# Schema validator built once at import time and reused for every example.
_V = Draft202012Validator(PLAN_SCHEMA)

# Extracts the output-token count from `oll`'s stderr (text such as
# "out 123 tok", matched case-insensitively).
_TOK = re.compile(r"out\s+(\d+)\s*tok", flags=re.IGNORECASE)
def _call(user: str, model: str, max_tokens: int):
    """Run the `oll` CLI once and return its reply with the reported token count.

    The user prompt is piped to the process on stdin; the model name, the
    shared system prompt, the token cap and a temperature of 0 are passed as
    command-line flags.

    Args:
        user: Prompt text written to the subprocess's stdin.
        model: Value for the ``--model`` flag.
        max_tokens: Value for the ``--max-tokens`` flag.

    Returns:
        A ``(stdout, out_tok)`` tuple. ``out_tok`` is the integer captured by
        ``_TOK`` from the process's stderr, or ``None`` when stderr contains
        no match.

    Raises:
        subprocess.TimeoutExpired: If the process runs longer than 300 seconds.

    The exit status of the process is not inspected.
    """
    command = [
        "oll",
        "--model", model,
        "--system", SYSTEM_PROMPT,
        "--max-tokens", str(max_tokens),
        "--temperature", "0",
    ]
    proc = subprocess.run(
        command,
        input=user,
        capture_output=True,
        text=True,
        timeout=300,
    )
    # Token usage is parsed from stderr; stdout carries the model's reply.
    found = _TOK.search(proc.stderr or "")
    reported = int(found.group(1)) if found else None
    return proc.stdout, reported
| def _categorize(out: str, out_tok: int | None, max_tokens: int): | |
| s = out.strip() | |
| if not s: | |
| return "empty", None | |
| i, j = s.find("{"), s.rfind("}") | |
| if i == -1: | |
| return "no_json", None | |
| near_cap = out_tok is not None and out_tok >= max_tokens - 50 | |
| if j < i: | |
| return ("truncated" if near_cap else "no_close_brace"), None | |
| try: | |
| plan = json.loads(s[i:j + 1]) | |
| except json.JSONDecodeError: | |
| return ("truncated" if near_cap else "malformed_json"), None | |
| errs = sorted(_V.iter_errors(plan), key=lambda e: e.path) | |
| if not errs: | |
| return "valid", None | |
| return "schema_invalid", errs[0].message[:90] | |
def main() -> None:
    """Run the diagnosis from the command line and print a category summary.

    Command-line flags:
        --n: Number of examples to generate and send (default 12).
        --model: Model name passed to ``_call`` (default "glm-5.1").
        --max-tokens: Output-token cap passed to ``_call`` and used by
            ``_categorize`` (default 8000).
        --seed: Seed for the example generator's RNG (default 4242).

    For each example a prompt is built from the example's ``dirty_df``, sent
    through ``_call``, and the response is classified with ``_categorize``.
    One line is printed per example, followed by the category breakdown and a
    valid / truncated / schema_invalid summary line.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--n", type=int, default=12)
    ap.add_argument("--model", type=str, default="glm-5.1")
    ap.add_argument("--max-tokens", type=int, default=8000)
    ap.add_argument("--seed", type=int, default=4242)
    args = ap.parse_args()

    rng = random.Random(args.seed)
    cats = Counter()
    print(f"Diagnosing {args.model} @ max_tokens={args.max_tokens} on {args.n} examples\n")
    for k in range(args.n):
        ex = make_example(rng)
        user = build_user_prompt(profile_dataframe(ex["dirty_df"]), ex["dirty_df"])
        out, out_tok = _call(user, args.model, args.max_tokens)
        cat, detail = _categorize(out, out_tok, args.max_tokens)
        cats[cat] += 1
        print(f"  ex{k:2d}: {cat:<16} out_tok={out_tok}"
              + (f"  [{detail}]" if detail else ""))

    print("\nBreakdown:", dict(cats))
    # Counter returns 0 for categories that never occurred.
    valid = cats["valid"]
    trunc = cats["truncated"]
    # With --n 0 (or a negative value) no example ran; report 0% rather than
    # dividing by zero or printing a negative rate.
    valid_rate = valid / args.n if args.n > 0 else 0.0
    print(f"valid={valid}/{args.n} ({valid_rate:.0%}) | truncated={trunc} "
          f"| schema_invalid={cats['schema_invalid']}")


if __name__ == "__main__":
    main()