scrubdata / eval /diagnose_model.py
OpenAI Codex
deploy: add sponsor:openai tag (Best Use of Codex) + Codex-hardened build
16dc556
Raw
History Blame Contribute Delete
3.04 kB
"""Diagnose vanilla-model failures: truncation vs genuine schema violation.
Runs N examples through an Ollama Cloud model, categorizing each output:
empty / no_json / no_close_brace / malformed_json / truncated / schema_invalid / valid
and reading `oll`'s stderr token counts to detect output hitting the cap.
uv run python -m eval.diagnose_model --n 12 --model glm-5.1 --max-tokens 8000
"""
from __future__ import annotations
import argparse
import json
import random
import re
import subprocess
from collections import Counter
from jsonschema import Draft202012Validator
from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt
from scrubdata.profiler import profile_dataframe
from training.generate import make_example
from .metrics import PLAN_SCHEMA
# Schema validator for model-produced plans; built once at import time and
# reused for every example.
_V = Draft202012Validator(PLAN_SCHEMA)

# Extracts the output-token count from a stderr fragment such as "out 123 tok"
# (case-insensitive); group 1 is the number.
_TOK = re.compile(r"out\s+(\d+)\s*tok", flags=re.IGNORECASE)
def _call(user: str, model: str, max_tokens: int):
    """Run one prompt through the ``oll`` CLI and return its raw output.

    Args:
        user: User prompt text, fed to the process on stdin.
        model: Model name passed via ``--model``.
        max_tokens: Output-token cap passed via ``--max-tokens``.

    Returns:
        A ``(stdout, out_tok)`` tuple. ``out_tok`` is the output-token count
        parsed from the process's stderr, or ``None`` when stderr contains no
        match for the token pattern.

    Raises:
        subprocess.TimeoutExpired: If the process runs longer than 300 seconds.
    """
    command = [
        "oll",
        "--model", model,
        "--system", SYSTEM_PROMPT,
        "--max-tokens", str(max_tokens),
        "--temperature", "0",
    ]
    # NOTE(review): the exit status is not inspected, so a failed invocation
    # that prints nothing is indistinguishable from an empty model reply.
    proc = subprocess.run(
        command,
        input=user,
        capture_output=True,
        text=True,
        timeout=300,
    )
    token_match = _TOK.search(proc.stderr or "")
    token_count = int(token_match.group(1)) if token_match else None
    return proc.stdout, token_count
def _categorize(out: str, out_tok: int | None, max_tokens: int):
s = out.strip()
if not s:
return "empty", None
i, j = s.find("{"), s.rfind("}")
if i == -1:
return "no_json", None
near_cap = out_tok is not None and out_tok >= max_tokens - 50
if j < i:
return ("truncated" if near_cap else "no_close_brace"), None
try:
plan = json.loads(s[i:j + 1])
except json.JSONDecodeError:
return ("truncated" if near_cap else "malformed_json"), None
errs = sorted(_V.iter_errors(plan), key=lambda e: e.path)
if not errs:
return "valid", None
return "schema_invalid", errs[0].message[:90]
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--n", type=int, default=12)
ap.add_argument("--model", type=str, default="glm-5.1")
ap.add_argument("--max-tokens", type=int, default=8000)
ap.add_argument("--seed", type=int, default=4242)
args = ap.parse_args()
rng = random.Random(args.seed)
cats = Counter()
print(f"Diagnosing {args.model} @ max_tokens={args.max_tokens} on {args.n} examples\n")
for k in range(args.n):
ex = make_example(rng)
user = build_user_prompt(profile_dataframe(ex["dirty_df"]), ex["dirty_df"])
out, out_tok = _call(user, args.model, args.max_tokens)
cat, detail = _categorize(out, out_tok, args.max_tokens)
cats[cat] += 1
print(f" ex{k:2d}: {cat:<16} out_tok={out_tok}"
+ (f" [{detail}]" if detail else ""))
print("\nBreakdown:", dict(cats))
valid = cats.get("valid", 0)
trunc = cats.get("truncated", 0)
print(f"valid={valid}/{args.n} ({valid/args.n:.0%}) | truncated={trunc} "
f"| schema_invalid={cats.get('schema_invalid', 0)}")
if __name__ == "__main__":
main()