scrubdata / eval /precision_curve.py
OpenAI Codex
deploy: add sponsor:openai tag (Best Use of Codex) + Codex-hardened build
16dc556
Raw
History Blame Contribute Delete
4.55 kB
"""WS1 deliverable: precision-coverage curve for the verified planner on real errors.

Sweeps the verifier threshold tau over the hospital benchmark and reports, per tau:

    precision = repair_prec    (of the cells we changed, how many match gold)
    coverage  = repair_recall  (of the real errors, how many we fixed)

GATE (publication plan): precision >= 0.70 at coverage >= 0.30. The verified planner
abstains on low-confidence merges instead of committing them — selective prediction at
the plan level, contract-preserving (dropped entries become review flags).

    uv run python -m eval.precision_curve                          # grounded heuristic planner
    uv run python -m eval.precision_curve --plan plan.json         # pre-captured model plan
    uv run python -m eval.precision_curve --plan plan.json --union # production pipeline
"""
from __future__ import annotations
import argparse
import json
from scrubdata.executor import apply_plan
from scrubdata.planner import mock_plan
from scrubdata.verifier import union_plans, verify_plan
from .run_real import _ensure_data, _load
from .run_real_multi import score as _cn_score # churn-neutral scoring
# Verifier thresholds swept by curve(); each value is passed to verify_plan as `tau`.
TAUS = [0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
def _repairs_only(plan: dict) -> dict:
"""Keep only the REPAIR decisions (canonicalize mappings); drop format/table ops.
This is the Baran-comparable protocol: precision over error-repair decisions,
not over convention standardization (dates->ISO etc., which the raw benchmark
stores as text and would flood the denominator)."""
import copy
out = copy.deepcopy(plan)
out["table_operations"] = []
for c in out.get("columns", []):
c["operations"] = [o for o in c.get("operations", [])
if o.get("op") == "canonicalize_categories"]
out["columns"] = [c for c in out.get("columns", []) if c.get("operations")]
return out
def curve(dirty, clean, base_plan: dict, label: str, union: bool = False, *,
          min_precision: float = 0.70, min_coverage: float = 0.30) -> list[dict]:
    """Sweep the verifier threshold over ``TAUS`` and report precision vs. coverage.

    For every tau the base plan is run through ``verify_plan``, optionally
    unioned with the ``mock_plan`` heuristic, reduced to repair decisions with
    ``_repairs_only``, applied to *dirty* with ``apply_plan`` and scored against
    *clean*. A table row is printed per tau, followed by a GATE verdict.

    Args:
        dirty: The dirty dataset handed to the planner, verifier and executor.
        clean: The gold dataset the cleaned output is scored against.
        base_plan: The raw plan to verify at each threshold.
        label: Human-readable name of the plan, shown in the table header.
        union: When true, union each verified plan with the heuristic plan
            (the production composition).
        min_precision: Precision a row must reach to clear the gate.
        min_coverage: Coverage (recall) a row must reach to clear the gate.

    Returns:
        One dict per tau with keys ``tau``, ``precision``, ``coverage``,
        ``changed`` and ``fixed``, in sweep order.
    """
    def clears_gate(precision, coverage) -> bool:
        # Single definition of the gate, shared by the row marker and the verdict.
        return precision >= min_precision and coverage >= min_coverage

    rows = []
    # The heuristic plan does not depend on tau, so build it once up front.
    heuristic = mock_plan(dirty) if union else None
    print(f"\n=== precision-coverage: {label} (hospital, 509 real errors) ===")
    print(f"{'tau':>5}{'precision':>11}{'coverage':>10}{'changed':>9}{'fixed':>7}")
    print("-" * 44)
    for tau in TAUS:
        plan = verify_plan(dirty, base_plan, tau=tau)
        if union:  # the production (active.py) composition
            plan = union_plans(plan, heuristic)
        plan = _repairs_only(plan)
        cleaned, _ = apply_plan(dirty, plan)
        m = _cn_score(dirty, clean, cleaned)
        rows.append({"tau": tau, "precision": m["precision"], "coverage": m["recall"],
                     "changed": m["_changed"], "fixed": m["_fixed"]})
        gate = " <-- GATE" if clears_gate(m["precision"], m["recall"]) else ""
        print(f"{tau:>5.2f}{m['precision']:>11.3f}{m['recall']:>10.3f}"
              f"{m['_changed']:>9}{m['_fixed']:>7}{gate}")

    ok = [r for r in rows if clears_gate(r["precision"], r["coverage"])]
    if ok:
        # Among the passing thresholds, report the one that fixes the most errors.
        best = max(ok, key=lambda r: r["coverage"])
        print(f"\nGATE: PASS at tau={best['tau']} (precision {best['precision']:.3f}, "
              f"coverage {best['coverage']:.3f})")
    else:
        # Nothing passed: show the closest attempt on the precision axis.
        hi = max(rows, key=lambda r: r["precision"])
        print(f"\nGATE: not cleared — max precision {hi['precision']:.3f} at "
              f"coverage {hi['coverage']:.3f} (tau={hi['tau']})")
    return rows
def main() -> None:
    """Command-line entry point for the precision-coverage sweep.

    Options:
        --plan PATH  Load the base plan from a JSON file instead of building
                     it with ``mock_plan``.
        --union      Union each verified plan with the grounded heuristic.
        --out PATH   Also write the curve rows to PATH as JSON.

    The benchmark data is ensured and loaded before the plan is read, then the
    sweep is delegated to ``curve``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--plan", type=str, default=None,
                    help="path to a captured raw plan JSON (e.g. the v6 model's)")
    ap.add_argument("--union", action="store_true",
                    help="union each verified plan with the grounded heuristic "
                         "(the shipped active.py pipeline)")
    ap.add_argument("--out", type=str, default=None, help="write curve rows to JSON")
    args = ap.parse_args()

    _ensure_data()
    dirty, clean = _load()
    if args.plan:
        # Context manager so the handle is closed even if the JSON is malformed.
        with open(args.plan, encoding="utf-8") as fh:
            base_plan = json.load(fh)
        label = f"model plan ({args.plan})" + (" + heuristic union" if args.union else "")
    else:
        base_plan = mock_plan(dirty)
        label = "grounded heuristic"

    rows = curve(dirty, clean, base_plan, label, union=args.union)
    if args.out:
        # Close (and therefore flush) the file before announcing it was written.
        with open(args.out, "w", encoding="utf-8") as fh:
            json.dump(rows, fh, indent=1)
        print(f"curve written to {args.out}")
# Run the sweep only when executed as a script or via `python -m`, not on import.
if __name__ == "__main__":
    main()