Spaces:
Running
Running
| """W1.b — self-consistency + verifier re-ranking on the WS1 hospital gate. | |
| Samples N raw plans at temperature from the fine-tuned LOCAL planner through the exact | |
| WS1 capture composition (make_batched_planner batch_size=4, no grounded wrapper, no | |
| union — matches eval/results/v6_hospital_raw_plan.json and the Modal capture; see | |
| eval/capture_plan_local.py), majority-votes mappings at (column, raw->canon) cell-edit | |
| level (keep entries in >= ceil(N/2) samples; vote share recorded), then runs the voted | |
| plan through the SHIPPED selective-prediction pipeline — verify_plan(tau) + union with | |
| the grounded heuristic — and scores against hospital's 509 real errors with the | |
| eval.precision_curve machinery. Also captures one greedy (temperature 0) plan as the | |
| reproduction anchor vs the shipped 0.905 @ 0.413. Measurement, not a ship decision. | |
| Decoding is format=json (grammar-constrained): without it the Q8 GGUF's first token | |
| degenerates into <tool_call> loops — the Modal bf16 captures suppressed the same two | |
| tokens (suppress_tokens=[151657, 151658]); this is the local equivalent. | |
| ollama create scrubdata-ft -f notebooks/Modelfile | |
| uv run python -m eval.sc_rerank --model scrubdata-ft --n 8 \ | |
| --out eval/results/sc_rerank.json | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import time | |
| from collections import Counter | |
| from scrubdata.executor import apply_plan | |
| from scrubdata.model_planner import _extract_json, make_batched_planner | |
| from scrubdata.planner import mock_plan | |
| from scrubdata.profiler import profile_dataframe | |
| from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt | |
| from scrubdata.verifier import union_plans, verify_plan | |
| from .precision_curve import TAUS, _repairs_only | |
| from .run_real import _ensure_data, _load | |
| from .run_real_multi import score as _cn_score | |
| SHIPPED = {"precision": 0.905, "coverage": 0.413, "tau": 0.5} # WS1 gate (aa48108) | |
| NUM_PREDICT = 4000 # batch 3 needs 2122 tokens; 2000 truncated 2/5 hospital batches | |
| def _salvage_json(text: str) -> dict | None: | |
| """Repair a generation truncated mid-JSON (done_reason=length): cut at the last | |
| structurally complete value and close the open brackets. Q8-local failure mode: | |
| greedy repetition loops inside a mapping never emit the closing brace; the entries | |
| before the loop are valid and (being duplicates) dedupe in the dict.""" | |
| i = text.find("{") | |
| if i == -1: | |
| return None | |
| stack, in_str, esc = [], False, False | |
| cut = None # (pos, closers) at last safe point | |
| for j, ch in enumerate(text[i:], start=i): | |
| if in_str: | |
| if esc: | |
| esc = False | |
| elif ch == "\\": | |
| esc = True | |
| elif ch == '"': | |
| in_str = False | |
| cut = (j, "".join(reversed(stack))) # after a complete string value/key | |
| elif ch == '"': | |
| in_str = True | |
| elif ch in "{[": | |
| stack.append("}" if ch == "{" else "]") | |
| elif ch in "}]": | |
| if not stack: | |
| return None | |
| stack.pop() | |
| cut = (j, "".join(reversed(stack))) | |
| if not stack or cut is None: | |
| return None | |
| frag = text[i:cut[0] + 1] | |
| # a cut after a KEY (`"key"` with no value yet) is invalid — drop the dangling key | |
| for cand in (frag + cut[1], | |
| frag.rsplit(",", 1)[0] + cut[1] if "," in frag else None): | |
| if cand is None: | |
| continue | |
| try: | |
| return json.loads(cand) | |
| except json.JSONDecodeError: | |
| continue | |
| return None | |
| def make_sampling_planner(model: str, temperature: float, seed: int, | |
| host: str = "http://localhost:11434", timeout: int = 600): | |
| """make_local_ollama_planner with temperature + seed exposed, format=json constrained | |
| (blocks the degenerate <tool_call> first token, like the Modal suppress_tokens).""" | |
| import urllib.request | |
| def planner(dirty_df, *_): | |
| profile = profile_dataframe(dirty_df) | |
| user = build_user_prompt(profile, dirty_df) | |
| payload = { | |
| "model": model, "stream": False, "format": "json", | |
| "messages": [{"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user}], | |
| "options": {"temperature": temperature, "seed": seed, | |
| "num_predict": NUM_PREDICT, "num_ctx": 16384}, | |
| } | |
| req = urllib.request.Request( | |
| host + "/api/chat", data=json.dumps(payload).encode(), | |
| headers={"Content-Type": "application/json"}) | |
| out, last_err = None, None | |
| for attempt in range(3): # ride out transient 500s / reloads | |
| try: | |
| with urllib.request.urlopen(req, timeout=timeout) as r: | |
| out = json.loads(r.read())["message"]["content"] | |
| break | |
| except Exception as e: # noqa: BLE001 | |
| last_err = str(e)[:120] | |
| time.sleep(10 * (attempt + 1)) | |
| if out is None: | |
| return {"__error__": last_err} | |
| plan = _extract_json(out) | |
| if plan is None: | |
| plan = _salvage_json(out) | |
| if plan is not None: | |
| plan["_salvaged"] = True | |
| if plan is None: | |
| return {"__error__": "no_json", "raw": out[:200]} | |
| plan.setdefault("table_operations", []) | |
| plan.setdefault("columns", []) | |
| plan.setdefault("flags", []) | |
| return plan | |
| return planner | |
| def capture_raw_plan(model: str, dirty, temperature: float, seed: int, | |
| host: str = "http://localhost:11434") -> tuple[dict, int]: | |
| """The WS1 capture composition: make_batched_planner(model, 4) — no grounded wrapper, | |
| no fallback (failed batches contribute nothing, as in the Modal capture). Returns | |
| (raw plan, n failed batches).""" | |
| raw = make_sampling_planner(model, temperature, seed, host=host) | |
| failed, salvaged = [0], [0] | |
| def counted(df, *_): | |
| p = raw(df) | |
| if not (isinstance(p, dict) and "__error__" not in p): | |
| failed[0] += 1 | |
| elif p.pop("_salvaged", False): | |
| salvaged[0] += 1 | |
| return p | |
| plan = make_batched_planner(counted, batch_size=4)(dirty) | |
| # sampling can emit malformed entries (a bare string in columns/operations): | |
| # drop non-dict items — the executor/verifier contract is dicts only | |
| plan["columns"] = [c for c in plan.get("columns", []) if isinstance(c, dict)] | |
| for c in plan["columns"]: | |
| c["operations"] = [o for o in c.get("operations", []) if isinstance(o, dict)] | |
| plan["flags"] = [f for f in plan.get("flags", []) if isinstance(f, dict)] | |
| plan["_capture"] = {"failed_batches": failed[0], "salvaged_batches": salvaged[0]} | |
| return plan, failed[0] | |
| def _entries(plan: dict): | |
| """Yield (column, raw, canon, grounded?) for every canonicalize mapping entry.""" | |
| for c in plan.get("columns", []): | |
| for o in c.get("operations", []): | |
| if o.get("op") != "canonicalize_categories": | |
| continue | |
| g = "reference taxonomy" in o.get("rationale", "") | |
| for r, cn in o.get("mapping", {}).items(): | |
| yield (c.get("name"), str(r), str(cn), g) | |
| def vote_plans(plans: list[dict], k: int) -> tuple[dict, dict]: | |
| """Majority-vote N raw plans at (column, raw->canon) cell-edit level: keep entries in | |
| >= k samples (grounded entries keep their rationale so the verifier passes them | |
| through, as in the shipped pipeline). Non-canonicalize ops and table ops are voted at | |
| (column, op-name) level. Returns (voted plan, vote diagnostics).""" | |
| n = len(plans) | |
| votes = Counter(e for p in plans for e in set(_entries(p))) | |
| kept = {e: v for e, v in votes.items() if v >= k} | |
| # column ops other than canonicalize, voted at op identity | |
| op_votes = Counter() | |
| op_proto: dict = {} | |
| for p in plans: | |
| seen = set() | |
| for c in p.get("columns", []): | |
| for o in c.get("operations", []): | |
| if o.get("op") == "canonicalize_categories": | |
| continue | |
| key = (c.get("name"), o.get("op")) | |
| if key not in seen: | |
| seen.add(key) | |
| op_votes[key] += 1 | |
| op_proto.setdefault(key, (o, c)) | |
| cols: dict = {} | |
| def _col(plan_col_name, proto_c): | |
| if plan_col_name not in cols: | |
| cols[plan_col_name] = { | |
| "name": plan_col_name, | |
| "detected_semantic_type": proto_c.get("detected_semantic_type", "categorical"), | |
| "issues": list(proto_c.get("issues", [])), "operations": []} | |
| return cols[plan_col_name] | |
| for (cname, _opn), v in sorted(op_votes.items(), key=lambda x: x[0][1] or ""): | |
| if v >= k: | |
| o, proto_c = op_proto[(cname, _opn)] | |
| _col(cname, proto_c)["operations"].append(json.loads(json.dumps(o))) | |
| proto_cols = {c.get("name"): c for p in plans for c in p.get("columns", [])} | |
| by_col: dict = {} | |
| # ascending vote order so on (column, raw) conflicts the higher-vote canon wins | |
| # (only reachable at k=1 / union-of-all; a majority threshold can keep one side only) | |
| for (cname, r, cn, g), v in sorted(kept.items(), key=lambda x: x[1]): | |
| by_col.setdefault((cname, g), {})[r] = cn | |
| for (cname, g), mapping in by_col.items(): | |
| col = _col(cname, proto_cols.get(cname, {})) | |
| col["operations"].append({ | |
| "op": "canonicalize_categories", "mapping": mapping, | |
| "rationale": ("Reconciled to the reference taxonomy (grounded, not " | |
| "free-generated); self-consistency voted." if g else | |
| f"Self-consistency majority vote over {n} samples.")}) | |
| voted = {"dataset_summary": plans[0].get("dataset_summary", ""), | |
| "table_operations": json.loads(json.dumps(plans[0].get("table_operations", []))), | |
| "columns": list(cols.values()), "flags": []} | |
| diag = {"n_samples": n, "threshold": k, "entries_union": len(votes), | |
| "entries_kept": len(kept), | |
| "vote_hist": dict(Counter(votes.values())), | |
| "kept_vote_share": {f"{c}|{r}->{cn}": round(v / n, 3) | |
| for (c, r, cn, _g), v in sorted(kept.items())}} | |
| return voted, diag | |
| def gate_point(dirty, clean, base_plan: dict, tau: float = 0.5, union: bool = True) -> dict: | |
| """One precision-curve point: verify(tau) [-> union heuristic] -> repairs-only score.""" | |
| plan = verify_plan(dirty, base_plan, tau=tau) | |
| if union: | |
| plan = union_plans(plan, mock_plan(dirty)) | |
| cleaned, _ = apply_plan(dirty, _repairs_only(plan)) | |
| m = _cn_score(dirty, clean, cleaned) | |
| return {"tau": tau, "precision": round(m["precision"], 4), "coverage": round(m["recall"], 4), | |
| "changed": m["_changed"], "fixed": m["_fixed"]} | |
| def main() -> None: | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--model", default="scrubdata-ft") | |
| ap.add_argument("--n", type=int, default=8) | |
| ap.add_argument("--temperature", type=float, default=0.7) | |
| ap.add_argument("--seed", type=int, default=100, help="base sampling seed (seed+i per sample)") | |
| ap.add_argument("--host", default="http://localhost:11434", | |
| help="ollama host (REPAIRED GGUF sha256 ef08cc6c... verified on " | |
| "0.21.2 with format=json; 0.30.7 silently IGNORES format=json " | |
| "for this model — the pre-repair GGUF sha 9caa0b2c degenerated " | |
| "on every runtime)") | |
| ap.add_argument("--out", type=str, default="eval/results/sc_rerank.json") | |
| ap.add_argument("--blob-sha256-prefix", default="", | |
| help="sha256 prefix of the served GGUF blob (provenance)") | |
| args = ap.parse_args() | |
| _ensure_data() | |
| dirty, clean = _load() | |
| # reproduction anchor: greedy capture through the same pipeline | |
| t0 = time.time() | |
| greedy, g_fb = capture_raw_plan(args.model, dirty, 0.0, args.seed, host=args.host) | |
| g_secs = time.time() - t0 | |
| g_point = gate_point(dirty, clean, greedy) | |
| g_cap = greedy.get("_capture", {}) | |
| print(f"[greedy anchor] {g_secs:.0f}s, capture={g_cap}, " | |
| f"union tau=0.5: {g_point['precision']:.3f} @ {g_point['coverage']:.3f} " | |
| f"(shipped {SHIPPED['precision']} @ {SHIPPED['coverage']})", flush=True) | |
| samples = [] | |
| for i in range(args.n): | |
| t0 = time.time() | |
| plan, fb = capture_raw_plan(args.model, dirty, args.temperature, | |
| args.seed + 1 + i, host=args.host) | |
| if fb and not list(_entries(plan)): # server hiccup ate the sample: one redo | |
| print(f"[sample {i + 1}/{args.n}] all batches failed — retrying once", flush=True) | |
| plan, fb = capture_raw_plan(args.model, dirty, args.temperature, | |
| args.seed + 1 + i, host=args.host) | |
| secs = time.time() - t0 | |
| pt = gate_point(dirty, clean, plan) | |
| samples.append({"seed": args.seed + 1 + i, "secs": round(secs, 1), | |
| "capture": plan.get("_capture", {}), | |
| "n_entries": len(set(_entries(plan))), | |
| "plan": plan, "point_tau05_union": pt}) | |
| print(f"[sample {i + 1}/{args.n}] {secs:.0f}s, capture={samples[-1]['capture']}, " | |
| f"entries={samples[-1]['n_entries']}, union tau=0.5: " | |
| f"{pt['precision']:.3f} @ {pt['coverage']:.3f}", flush=True) | |
| json.dump(samples, open(args.out + ".partial", "w")) # checkpoint | |
| k = math.ceil(args.n / 2) | |
| voted, diag = vote_plans([s["plan"] for s in samples], k) | |
| print(f"\n[vote] union {diag['entries_union']} entries -> kept {diag['entries_kept']} " | |
| f"(>= {k}/{args.n} votes); hist {diag['vote_hist']}") | |
| rows = [gate_point(dirty, clean, voted, tau=t) for t in TAUS] | |
| print(f"\n=== voted plan + verify + heuristic union (hospital, 509 real errors) ===") | |
| print(f"{'tau':>5}{'precision':>11}{'coverage':>10}{'changed':>9}{'fixed':>7}") | |
| for r in rows: | |
| print(f"{r['tau']:>5.2f}{r['precision']:>11.3f}{r['coverage']:>10.3f}" | |
| f"{r['changed']:>9}{r['fixed']:>7}") | |
| v_point = next(r for r in rows if r["tau"] == 0.5) | |
| print(f"\nvoted-union @ tau=0.5: {v_point['precision']:.3f} @ {v_point['coverage']:.3f} " | |
| f"vs shipped {SHIPPED['precision']} @ {SHIPPED['coverage']}") | |
| # ablation (a): best single sample — max precision among points with coverage >= 0.30 | |
| eligible = [s for s in samples if s["point_tau05_union"]["coverage"] >= 0.30] | |
| pool = eligible or samples # if nothing reaches 0.30, fall back + flag it | |
| best = max(pool, key=lambda s: s["point_tau05_union"]["precision"]) | |
| best_single = dict(best["point_tau05_union"], | |
| seed=best["seed"], coverage_floor_met=bool(eligible)) | |
| print(f"[ablation a] best single (seed {best['seed']}): " | |
| f"{best_single['precision']:.3f} @ {best_single['coverage']:.3f} " | |
| f"(eligible >=0.30 cov: {len(eligible)}/{len(samples)})") | |
| # ablation (b): union of ALL samples (k=1, conflicts -> higher-vote canon), verify+union | |
| union_plan, union_diag = vote_plans([s["plan"] for s in samples], 1) | |
| u_point = gate_point(dirty, clean, union_plan) | |
| print(f"[ablation b] union-of-all ({union_diag['entries_union']} entries): " | |
| f"{u_point['precision']:.3f} @ {u_point['coverage']:.3f}") | |
| out = {"model": args.model, "n_samples": args.n, "temperature": args.temperature, | |
| "base_seed": args.seed, "host": args.host, "threshold": k, | |
| "decoding": {"temperature": args.temperature, "format": "json", | |
| "num_predict": NUM_PREDICT, "num_ctx": 16384}, | |
| "model_blob_sha256_prefix": args.blob_sha256_prefix or None, | |
| "shipped_reference": SHIPPED, | |
| "greedy_anchor": {"secs": round(g_secs, 1), "failed_batches": g_fb, | |
| "point_tau05_union": g_point}, | |
| "per_sample_runtimes": [s["secs"] for s in samples], | |
| "per_sample": [{kk: v for kk, v in s.items() if kk != "plan"} for s in samples], | |
| "vote": diag, "voted_curve": rows, "voted": v_point, | |
| "best_single": best_single, | |
| "union_all": dict(u_point, entries_union=union_diag["entries_union"]), | |
| "voted_plan": voted} | |
| json.dump(out, open(args.out, "w"), indent=1) | |
| print(f"results written to {args.out}") | |
| if __name__ == "__main__": | |
| main() | |