scrubdata / eval /sc_rerank.py
OpenAI Codex
deploy: add sponsor:openai tag (Best Use of Codex) + Codex-hardened build
16dc556
Raw
History Blame Contribute Delete
16.6 kB
"""W1.b — self-consistency + verifier re-ranking on the WS1 hospital gate.
Samples N raw plans at a non-zero sampling temperature from the fine-tuned LOCAL planner through the exact
WS1 capture composition (make_batched_planner batch_size=4, no grounded wrapper, no
union — matches eval/results/v6_hospital_raw_plan.json and the Modal capture; see
eval/capture_plan_local.py), majority-votes mappings at (column, raw->canon) cell-edit
level (keep entries in >= ceil(N/2) samples, so for even N an exact 50/50 tie passes; vote share recorded), then runs the voted
plan through the SHIPPED selective-prediction pipeline — verify_plan(tau) + union with
the grounded heuristic — and scores against hospital's 509 real errors with the
eval.precision_curve machinery. Also captures one greedy (temperature 0) plan as the
reproduction anchor vs the shipped 0.905 @ 0.413. Measurement, not a ship decision.
Decoding is format=json (grammar-constrained): without it the Q8 GGUF's first token
degenerates into <tool_call> loops — the Modal bf16 captures suppressed the same two
tokens (suppress_tokens=[151657, 151658]); this is the local equivalent.
ollama create scrubdata-ft -f notebooks/Modelfile
uv run python -m eval.sc_rerank --model scrubdata-ft --n 8 \
--out eval/results/sc_rerank.json
"""
from __future__ import annotations
import argparse
import json
import math
import time
from collections import Counter
from scrubdata.executor import apply_plan
from scrubdata.model_planner import _extract_json, make_batched_planner
from scrubdata.planner import mock_plan
from scrubdata.profiler import profile_dataframe
from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt
from scrubdata.verifier import union_plans, verify_plan
from .precision_curve import TAUS, _repairs_only
from .run_real import _ensure_data, _load
from .run_real_multi import score as _cn_score
# Reference point of the shipped WS1 hospital gate (commit aa48108) that every
# measurement below is compared against.
SHIPPED = dict(precision=0.905, coverage=0.413, tau=0.5)
# Per-request generation cap: hospital batch 3 needs 2122 tokens, and a cap of
# 2000 truncated 2 of the 5 hospital batches.
NUM_PREDICT = 4_000
def _salvage_json(text: str) -> dict | None:
"""Repair a generation truncated mid-JSON (done_reason=length): cut at the last
structurally complete value and close the open brackets. Q8-local failure mode:
greedy repetition loops inside a mapping never emit the closing brace; the entries
before the loop are valid and (being duplicates) dedupe in the dict."""
i = text.find("{")
if i == -1:
return None
stack, in_str, esc = [], False, False
cut = None # (pos, closers) at last safe point
for j, ch in enumerate(text[i:], start=i):
if in_str:
if esc:
esc = False
elif ch == "\\":
esc = True
elif ch == '"':
in_str = False
cut = (j, "".join(reversed(stack))) # after a complete string value/key
elif ch == '"':
in_str = True
elif ch in "{[":
stack.append("}" if ch == "{" else "]")
elif ch in "}]":
if not stack:
return None
stack.pop()
cut = (j, "".join(reversed(stack)))
if not stack or cut is None:
return None
frag = text[i:cut[0] + 1]
# a cut after a KEY (`"key"` with no value yet) is invalid — drop the dangling key
for cand in (frag + cut[1],
frag.rsplit(",", 1)[0] + cut[1] if "," in frag else None):
if cand is None:
continue
try:
return json.loads(cand)
except json.JSONDecodeError:
continue
return None
def make_sampling_planner(model: str, temperature: float, seed: int,
                          host: str = "http://localhost:11434", timeout: int = 600):
    """make_local_ollama_planner with temperature + seed exposed, format=json constrained
    (blocks the degenerate <tool_call> first token, like the Modal suppress_tokens).

    The returned ``planner(dirty_df, *_)`` profiles the frame, sends one /api/chat
    request, and returns a plan dict:
      - on success: the parsed plan with ``table_operations``/``columns``/``flags``
        defaulted to ``[]``; a plan repaired by ``_salvage_json`` carries
        ``"_salvaged": True``;
      - after 3 failed requests: ``{"__error__": <truncated exception text>}``;
      - when no JSON can be recovered: ``{"__error__": "no_json", "raw": <prefix>}``.
    """
    import urllib.request

    # tolerate a trailing slash on --host (would otherwise yield "//api/chat")
    url = host.rstrip("/") + "/api/chat"
    attempts = 3

    def planner(dirty_df, *_):
        profile = profile_dataframe(dirty_df)
        user = build_user_prompt(profile, dirty_df)
        payload = {
            "model": model, "stream": False, "format": "json",
            "messages": [{"role": "system", "content": SYSTEM_PROMPT},
                         {"role": "user", "content": user}],
            "options": {"temperature": temperature, "seed": seed,
                        "num_predict": NUM_PREDICT, "num_ctx": 16384},
        }
        req = urllib.request.Request(
            url, data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"})
        out, last_err = None, None
        for attempt in range(attempts):  # ride out transient 500s / reloads
            try:
                with urllib.request.urlopen(req, timeout=timeout) as r:
                    out = json.loads(r.read())["message"]["content"]
                break
            except Exception as e:  # noqa: BLE001
                last_err = str(e)[:120]
                # back off before the next try only; never sleep after the last one
                if attempt + 1 < attempts:
                    time.sleep(10 * (attempt + 1))
        if out is None:
            return {"__error__": last_err}
        plan = _extract_json(out)
        if plan is None:
            plan = _salvage_json(out)
            if plan is not None:
                plan["_salvaged"] = True
        if plan is None:
            return {"__error__": "no_json", "raw": out[:200]}
        plan.setdefault("table_operations", [])
        plan.setdefault("columns", [])
        plan.setdefault("flags", [])
        return plan

    return planner
def capture_raw_plan(model: str, dirty, temperature: float, seed: int,
                     host: str = "http://localhost:11434") -> tuple[dict, int]:
    """The WS1 capture composition: make_batched_planner(model, 4) — no grounded wrapper,
    no fallback (failed batches contribute nothing, as in the Modal capture).

    Returns ``(raw plan, n failed batches)``. The plan also carries
    ``_capture = {"failed_batches": ..., "salvaged_batches": ...}`` and is sanitized
    so that ``columns``, each column's ``operations``, and ``flags`` are lists of
    dicts only.
    """
    raw = make_sampling_planner(model, temperature, seed, host=host)
    failed = salvaged = 0

    def counted(df, *_):
        nonlocal failed, salvaged
        p = raw(df)
        if not (isinstance(p, dict) and "__error__" not in p):
            failed += 1
        elif p.pop("_salvaged", False):
            salvaged += 1
        return p

    plan = make_batched_planner(counted, batch_size=4)(dirty)

    def _dicts(items):
        # a non-list container (e.g. JSON null) counts as empty instead of raising
        if not isinstance(items, (list, tuple)):
            return []
        return [x for x in items if isinstance(x, dict)]

    # sampling can emit malformed entries (a bare string in columns/operations, or a
    # null container): keep dict items only — the executor/verifier contract is dicts
    plan["columns"] = _dicts(plan.get("columns"))
    for c in plan["columns"]:
        c["operations"] = _dicts(c.get("operations"))
    plan["flags"] = _dicts(plan.get("flags"))
    plan["_capture"] = {"failed_batches": failed, "salvaged_batches": salvaged}
    return plan, failed
def _entries(plan: dict):
"""Yield (column, raw, canon, grounded?) for every canonicalize mapping entry."""
for c in plan.get("columns", []):
for o in c.get("operations", []):
if o.get("op") != "canonicalize_categories":
continue
g = "reference taxonomy" in o.get("rationale", "")
for r, cn in o.get("mapping", {}).items():
yield (c.get("name"), str(r), str(cn), g)
def vote_plans(plans: list[dict], k: int) -> tuple[dict, dict]:
    """Majority-vote N raw plans into a single plan.

    canonicalize_categories mappings are voted at (column, raw->canon, grounded?)
    cell-edit level: an entry is kept when it appears in >= ``k`` samples (each sample
    votes at most once per entry). Grounded entries keep a "reference taxonomy"
    rationale in their own op so the verifier passes them through, as in the shipped
    pipeline. Other column ops are voted at (column, op-name) level and the first-seen
    op is copied verbatim.

    NOTE: table operations and the dataset summary are NOT voted; they are copied
    from the first sample. Flags are dropped.

    Args:
        plans: the raw sampled plans (must be non-empty).
        k: minimum number of samples an entry/op must appear in to be kept.

    Returns:
        (voted plan, vote diagnostics).

    Raises:
        ValueError: if ``plans`` is empty.
    """
    if not plans:
        raise ValueError("vote_plans needs at least one plan")
    n = len(plans)
    # set() so a sample votes at most once for each entry
    votes = Counter(e for p in plans for e in set(_entries(p)))
    kept = {e: v for e, v in votes.items() if v >= k}
    # column ops other than canonicalize, voted at op identity
    op_votes = Counter()
    op_proto: dict = {}
    for p in plans:
        seen = set()
        for c in p.get("columns", []):
            for o in c.get("operations", []):
                if o.get("op") == "canonicalize_categories":
                    continue
                key = (c.get("name"), o.get("op"))
                if key not in seen:
                    seen.add(key)
                    op_votes[key] += 1
                    op_proto.setdefault(key, (o, c))
    cols: dict = {}

    def _col(plan_col_name, proto_c):
        if plan_col_name not in cols:
            cols[plan_col_name] = {
                "name": plan_col_name,
                "detected_semantic_type": proto_c.get("detected_semantic_type", "categorical"),
                "issues": list(proto_c.get("issues", [])), "operations": []}
        return cols[plan_col_name]

    for (cname, _opn), v in sorted(op_votes.items(), key=lambda x: x[0][1] or ""):
        if v >= k:
            o, proto_c = op_proto[(cname, _opn)]
            _col(cname, proto_c)["operations"].append(json.loads(json.dumps(o)))
    proto_cols = {c.get("name"): c for p in plans for c in p.get("columns", [])}
    by_col: dict = {}
    # Ascending vote order so that, within one (column, grounded?) group, the
    # higher-vote canon wins a (column, raw) conflict. The sort is stable, so exact
    # ties fall back to first-seen order. Conflicts survive the threshold at k=1
    # (union-of-all) AND, for even N with k = N/2 (the default N=8, k=4), on an exact
    # 50/50 split. A grounded and an ungrounded canon for the same raw value land in
    # separate ops and are not reconciled here.
    for (cname, r, cn, g), v in sorted(kept.items(), key=lambda x: x[1]):
        by_col.setdefault((cname, g), {})[r] = cn
    for (cname, g), mapping in by_col.items():
        col = _col(cname, proto_cols.get(cname, {}))
        col["operations"].append({
            "op": "canonicalize_categories", "mapping": mapping,
            "rationale": ("Reconciled to the reference taxonomy (grounded, not "
                          "free-generated); self-consistency voted." if g else
                          f"Self-consistency majority vote over {n} samples.")})
    voted = {"dataset_summary": plans[0].get("dataset_summary", ""),
             "table_operations": json.loads(json.dumps(plans[0].get("table_operations", []))),
             "columns": list(cols.values()), "flags": []}
    diag = {"n_samples": n, "threshold": k, "entries_union": len(votes),
            "entries_kept": len(kept),
            "vote_hist": dict(Counter(votes.values())),
            "kept_vote_share": {f"{c}|{r}->{cn}": round(v / n, 3)
                                for (c, r, cn, _g), v in sorted(kept.items())}}
    return voted, diag
def gate_point(dirty, clean, base_plan: dict, tau: float = 0.5, union: bool = True) -> dict:
    """Score a single point of the precision curve.

    Pipeline: ``verify_plan`` at ``tau``, optionally unioned with the heuristic
    ``mock_plan``, reduced to repairs only, applied to ``dirty`` and scored against
    ``clean``. Returns tau, precision, coverage (the scorer's recall, both rounded to
    4 places) and the changed/fixed cell counts.
    """
    verified = verify_plan(dirty, base_plan, tau=tau)
    final_plan = union_plans(verified, mock_plan(dirty)) if union else verified
    repaired, _log = apply_plan(dirty, _repairs_only(final_plan))
    metrics = _cn_score(dirty, clean, repaired)
    return dict(
        tau=tau,
        precision=round(metrics["precision"], 4),
        coverage=round(metrics["recall"], 4),
        changed=metrics["_changed"],
        fixed=metrics["_fixed"],
    )
def main() -> None:
    """CLI entry point.

    Steps:
      1. Capture a greedy (temperature 0) anchor plan.
      2. Capture N temperature samples, checkpointing them to ``<out>.partial``
         after each one.
      3. Majority-vote the samples and score the voted plan across ``TAUS``.
      4. Run two ablations: best single sample, and union of all samples.
      5. Write everything to ``--out`` as JSON.
    """
    from pathlib import Path

    ap = argparse.ArgumentParser()
    ap.add_argument("--model", default="scrubdata-ft")
    ap.add_argument("--n", type=int, default=8)
    ap.add_argument("--temperature", type=float, default=0.7)
    ap.add_argument("--seed", type=int, default=100, help="base sampling seed (seed+i per sample)")
    ap.add_argument("--host", default="http://localhost:11434",
                    help="ollama host (REPAIRED GGUF sha256 ef08cc6c... verified on "
                         "0.21.2 with format=json; 0.30.7 silently IGNORES format=json "
                         "for this model — the pre-repair GGUF sha 9caa0b2c degenerated "
                         "on every runtime)")
    ap.add_argument("--out", type=str, default="eval/results/sc_rerank.json")
    ap.add_argument("--blob-sha256-prefix", default="",
                    help="sha256 prefix of the served GGUF blob (provenance)")
    args = ap.parse_args()
    if args.n < 1:  # voting and the ablations need at least one sample
        ap.error("--n must be >= 1")
    out_path = Path(args.out)
    partial_path = Path(args.out + ".partial")
    # create the results directory up front, so a bad --out fails before the slow
    # sampling loop rather than at the first checkpoint write
    out_path.parent.mkdir(parents=True, exist_ok=True)
    _ensure_data()
    dirty, clean = _load()
    # reproduction anchor: greedy capture through the same pipeline
    t0 = time.time()
    greedy, g_fb = capture_raw_plan(args.model, dirty, 0.0, args.seed, host=args.host)
    g_secs = time.time() - t0
    g_point = gate_point(dirty, clean, greedy)
    g_cap = greedy.get("_capture", {})
    print(f"[greedy anchor] {g_secs:.0f}s, capture={g_cap}, "
          f"union tau=0.5: {g_point['precision']:.3f} @ {g_point['coverage']:.3f} "
          f"(shipped {SHIPPED['precision']} @ {SHIPPED['coverage']})", flush=True)
    samples = []
    for i in range(args.n):
        t0 = time.time()
        plan, fb = capture_raw_plan(args.model, dirty, args.temperature,
                                    args.seed + 1 + i, host=args.host)
        # some batches failed AND no canonicalize entries survived: treat as a
        # server hiccup and redo the sample once with the same seed
        if fb and not list(_entries(plan)):
            print(f"[sample {i + 1}/{args.n}] all batches failed — retrying once", flush=True)
            plan, fb = capture_raw_plan(args.model, dirty, args.temperature,
                                        args.seed + 1 + i, host=args.host)
        secs = time.time() - t0
        pt = gate_point(dirty, clean, plan)
        samples.append({"seed": args.seed + 1 + i, "secs": round(secs, 1),
                        "capture": plan.get("_capture", {}),
                        "n_entries": len(set(_entries(plan))),
                        "plan": plan, "point_tau05_union": pt})
        print(f"[sample {i + 1}/{args.n}] {secs:.0f}s, capture={samples[-1]['capture']}, "
              f"entries={samples[-1]['n_entries']}, union tau=0.5: "
              f"{pt['precision']:.3f} @ {pt['coverage']:.3f}", flush=True)
        with open(partial_path, "w") as fh:  # checkpoint (closed/flushed each sample)
            json.dump(samples, fh)
    k = math.ceil(args.n / 2)
    voted, diag = vote_plans([s["plan"] for s in samples], k)
    print(f"\n[vote] union {diag['entries_union']} entries -> kept {diag['entries_kept']} "
          f"(>= {k}/{args.n} votes); hist {diag['vote_hist']}")
    rows = [gate_point(dirty, clean, voted, tau=t) for t in TAUS]
    print("\n=== voted plan + verify + heuristic union (hospital, 509 real errors) ===")
    print(f"{'tau':>5}{'precision':>11}{'coverage':>10}{'changed':>9}{'fixed':>7}")
    for r in rows:
        print(f"{r['tau']:>5.2f}{r['precision']:>11.3f}{r['coverage']:>10.3f}"
              f"{r['changed']:>9}{r['fixed']:>7}")
    # tolerate float representation of the 0.5 point; score it directly if TAUS
    # has no such point instead of raising StopIteration
    v_point = next((r for r in rows if math.isclose(r["tau"], 0.5)), None)
    if v_point is None:
        v_point = gate_point(dirty, clean, voted, tau=0.5)
    print(f"\nvoted-union @ tau=0.5: {v_point['precision']:.3f} @ {v_point['coverage']:.3f} "
          f"vs shipped {SHIPPED['precision']} @ {SHIPPED['coverage']}")
    # ablation (a): best single sample — max precision among points with coverage >= 0.30
    eligible = [s for s in samples if s["point_tau05_union"]["coverage"] >= 0.30]
    pool = eligible or samples  # if nothing reaches 0.30, fall back + flag it
    best = max(pool, key=lambda s: s["point_tau05_union"]["precision"])
    best_single = dict(best["point_tau05_union"],
                       seed=best["seed"], coverage_floor_met=bool(eligible))
    print(f"[ablation a] best single (seed {best['seed']}): "
          f"{best_single['precision']:.3f} @ {best_single['coverage']:.3f} "
          f"(eligible >=0.30 cov: {len(eligible)}/{len(samples)})")
    # ablation (b): union of ALL samples (k=1, conflicts -> higher-vote canon), verify+union
    union_plan, union_diag = vote_plans([s["plan"] for s in samples], 1)
    u_point = gate_point(dirty, clean, union_plan)
    print(f"[ablation b] union-of-all ({union_diag['entries_union']} entries): "
          f"{u_point['precision']:.3f} @ {u_point['coverage']:.3f}")
    out = {"model": args.model, "n_samples": args.n, "temperature": args.temperature,
           "base_seed": args.seed, "host": args.host, "threshold": k,
           "decoding": {"temperature": args.temperature, "format": "json",
                        "num_predict": NUM_PREDICT, "num_ctx": 16384},
           "model_blob_sha256_prefix": args.blob_sha256_prefix or None,
           "shipped_reference": SHIPPED,
           "greedy_anchor": {"secs": round(g_secs, 1), "failed_batches": g_fb,
                             "point_tau05_union": g_point},
           "per_sample_runtimes": [s["secs"] for s in samples],
           "per_sample": [{kk: v for kk, v in s.items() if kk != "plan"} for s in samples],
           "vote": diag, "voted_curve": rows, "voted": v_point,
           "best_single": best_single,
           "union_all": dict(u_point, entries_union=union_diag["entries_union"]),
           "voted_plan": voted}
    with open(out_path, "w") as fh:
        json.dump(out, fh, indent=1)
    print(f"results written to {args.out}")
if __name__ == "__main__":
main()