"""WS4 learned-repair baselines: scoring + Jellyfish prompt construction. Both baselines bypass plan dicts (the executor is column-level by design; learned repair is per-cell) — they produce repaired DataFrames scored by the SAME churn-neutral `eval.run_real_multi.score` as every other row of the money table. * Baran: repaired CSVs come from eval/run_baran.py (pinned env). Score here: uv run python -m eval.baselines_learned --score-baran * Jellyfish: prompts built here (unit-testable without a GPU), executed by scripts/modal_jellyfish.py (vLLM on Modal), scored in-run with the same `score`. Jellyfish has NO repair task — we compose its two published cell-level tasks: error detection (yes/no per cell) then data imputation (infer the flagged cell with the attribute removed). Prompt templates are verbatim from the NECOUDBFM/Jellyfish-13B model card; this composition is OURS, not theirs (disclosed in the paper). """ from __future__ import annotations import argparse import json from pathlib import Path SYSTEM_MESSAGE = ("You are an AI assistant that follows instruction extremely well. " "Help as much as you can.") _ED_TEMPLATE = ( "Your task is to determine if there is an error in the value of a specific " "attribute within the whole record provided.\n" "The attributes may include {attrs}.\n" "Errors may include, but are not limited to, spelling errors, inconsistencies, " "or values that don't make sense given the context of the whole record.\n" "Record [{record}]\n" "Attribute for Verification: [{col}: {val}]\n" "Question: Is there an error in the value of {col}? " "Choose your answer from: [Yes, No]." ) _DI_TEMPLATE = ( "You are presented with a {keyword} record that is missing a specific attribute: " "{col}.\n" "Your task is to deduce or infer the value of {col} using the available " "information in the record.\n" "You may be provided with fields like {attrs} to help you in the inference.\n" "Record: [{record}]\n" "Based on the provided record, what would you infer is the value for the missing " "attribute {col}?\n" "Answer only the value of {col}." ) def wrap_prompt(user_message: str) -> str: """The Jellyfish-13B chat scaffold (verbatim from the model card).""" return f"{SYSTEM_MESSAGE}\n\n### Instruction:\n\n{user_message}\n\n### Response:\n\n" def _serialize(record: dict, skip: str | None = None) -> str: return ", ".join(f"{k}: {v}" for k, v in record.items() if k != skip) def ed_prompt(record: dict, col: str) -> str: """Error-detection prompt (whole-record form) for one cell.""" return wrap_prompt(_ED_TEMPLATE.format( attrs=", ".join(record.keys()), record=_serialize(record), col=col, val=record[col])) def di_prompt(record: dict, col: str, keyword: str) -> str: """Data-imputation prompt for a flagged cell — the attribute is REMOVED from the serialized record so the model infers, not copies.""" attrs = [k for k in record.keys() if k != col] return wrap_prompt(_DI_TEMPLATE.format( keyword=keyword, col=col, attrs=", ".join(attrs), record=_serialize(record, skip=col))) def parse_ed(text: str) -> bool: """True = the model says the cell is erroneous.""" return text.strip().lower().lstrip("[").startswith("yes") def parse_di(text: str, original: str) -> str: """Imputed value, or the original (abstain) when the answer is unusable — empty, multi-line/rambling, or implausibly long for a cell.""" ans = text.strip().strip('"').strip() if not ans or "\n" in ans or len(ans) > 80: return original return ans # ---------------------------------------------------------------- Baran scoring def score_baran(repaired_dir: str = "eval/results/baran", out: str = "eval/results/baran_raha.json") -> dict: """Score every _seed_repaired.csv against (dirty, clean) under the identical churn-neutral protocol; macro REAL-F1 mean ± 95% CI over seeds.""" import collections import pandas as pd from .run_real_multi import _raha_pair, score per_seed: dict[int, list] = collections.defaultdict(list) per_ds = [] for p in sorted(Path(repaired_dir).glob("*_seed*_repaired.csv")): name, seed = p.stem.rsplit("_repaired", 1)[0].rsplit("_seed", 1) repaired = pd.read_csv(p, dtype=str, keep_default_na=False) dirty, clean = _raha_pair(name) m = score(dirty, clean, repaired) per_seed[int(seed)].append(m) per_ds.append({"name": name, "seed": int(seed), **{k: v for k, v in m.items()}}) print(f" {name:<10} seed{seed}: F1={m['f1']:.3f} P={m['precision']:.3f} " f"R={m['recall']:.3f} dmg={m['damage']:.3f}") if not per_seed: raise SystemExit(f"no repaired CSVs found in {repaired_dir}") def mean(xs): xs = list(xs) return sum(xs) / len(xs) if xs else 0.0 seed_f1 = [mean(m["f1"] for m in ms) for ms in per_seed.values()] mu = mean(seed_f1) var = mean([(x - mu) ** 2 for x in seed_f1]) ci = 1.96 * (var ** 0.5) / (len(seed_f1) ** 0.5) result = { "system": "Baran (oracle detection, 20 gold labels)", "real_f1": mu, "real_f1_ci": ci, "real_f1_per_seed": seed_f1, "damage": mean(mean(m["damage"] for m in ms) for ms in per_seed.values()), "precision": mean(mean(m["precision"] for m in ms) for ms in per_seed.values()), "recall": mean(mean(m["recall"] for m in ms) for ms in per_seed.values()), "n_seeds": len(per_seed), "per_dataset": per_ds, "protocol_note": "upper bound: oracle error positions + 20 gold-labeled tuples " "(its package default); damage=0 by construction", } json.dump(result, open(out, "w"), indent=1) print(f"\nBaran macro REAL-F1 {mu:.3f} ± {ci:.3f} (n={len(seed_f1)} seeds) -> {out}") return result if __name__ == "__main__": ap = argparse.ArgumentParser() ap.add_argument("--score-baran", action="store_true") args = ap.parse_args() if args.score_baran: score_baran()