Spaces:
Running
Running
File size: 6,160 Bytes
"""WS4 learned-repair baselines: scoring + Jellyfish prompt construction.

Both baselines bypass plan dicts (the executor is column-level by design; learned repair
is per-cell) — they produce repaired DataFrames scored by the SAME churn-neutral
`eval.run_real_multi.score` as every other row of the money table.

* Baran: repaired CSVs come from eval/run_baran.py (pinned env). Score here:
      uv run python -m eval.baselines_learned --score-baran
* Jellyfish: prompts built here (unit-testable without a GPU), executed by
  scripts/modal_jellyfish.py (vLLM on Modal), scored in-run with the same `score`.
  Jellyfish has NO repair task — we compose its two published cell-level tasks:
  error detection (yes/no per cell) then data imputation (infer the flagged cell with the
  attribute removed). Prompt templates are verbatim from the NECOUDBFM/Jellyfish-13B model
  card; this composition is OURS, not theirs (disclosed in the paper).
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
# NOTE: the module docstring states these prompt templates are verbatim from the
# Jellyfish model card, so the wording (including any odd grammar) is kept as-is.

# System message; `wrap_prompt` places it ahead of the "### Instruction:" section.
SYSTEM_MESSAGE = ("You are an AI assistant that follows instruction extremely well. "
                  "Help as much as you can.")
# Error-detection task text. Placeholders: {attrs}, {record}, {col}, {val};
# filled in by `ed_prompt`.
_ED_TEMPLATE = (
    "Your task is to determine if there is an error in the value of a specific "
    "attribute within the whole record provided.\n"
    "The attributes may include {attrs}.\n"
    "Errors may include, but are not limited to, spelling errors, inconsistencies, "
    "or values that don't make sense given the context of the whole record.\n"
    "Record [{record}]\n"
    "Attribute for Verification: [{col}: {val}]\n"
    "Question: Is there an error in the value of {col}? "
    "Choose your answer from: [Yes, No]."
)
# Data-imputation task text. Placeholders: {keyword}, {col}, {attrs}, {record};
# filled in by `di_prompt`, which leaves the target column out of {attrs} and {record}.
_DI_TEMPLATE = (
    "You are presented with a {keyword} record that is missing a specific attribute: "
    "{col}.\n"
    "Your task is to deduce or infer the value of {col} using the available "
    "information in the record.\n"
    "You may be provided with fields like {attrs} to help you in the inference.\n"
    "Record: [{record}]\n"
    "Based on the provided record, what would you infer is the value for the missing "
    "attribute {col}?\n"
    "Answer only the value of {col}."
)
def wrap_prompt(user_message: str) -> str:
    """Embed *user_message* in the Jellyfish-13B chat scaffold.

    The result is the system message, an ``### Instruction:`` section holding
    *user_message*, and a trailing, empty ``### Response:`` section.
    """
    scaffold = "{}\n\n### Instruction:\n\n{}\n\n### Response:\n\n"
    return scaffold.format(SYSTEM_MESSAGE, user_message)
def _serialize(record: dict, skip: str | None = None) -> str:
    """Render *record* as ``key: value`` pairs joined by ``", "``.

    Any pair whose key equals *skip* is left out; pairs keep the record's own
    iteration order.
    """
    pairs = []
    for key, value in record.items():
        if key != skip:
            pairs.append(f"{key}: {value}")
    return ", ".join(pairs)
def ed_prompt(record: dict, col: str) -> str:
    """Build the wrapped error-detection prompt for the cell ``record[col]``.

    The whole record (target column included) is serialized into the prompt,
    and the question asks whether the value of *col* is erroneous. Indexing
    ``record[col]`` fails if *col* is not a key of *record*.
    """
    attribute_names = ", ".join(record.keys())
    whole_record = _serialize(record)
    cell_value = record[col]
    message = _ED_TEMPLATE.format(
        attrs=attribute_names,
        record=whole_record,
        col=col,
        val=cell_value,
    )
    return wrap_prompt(message)
def di_prompt(record: dict, col: str, keyword: str) -> str:
    """Build the wrapped data-imputation prompt for a flagged cell.

    The target attribute *col* is dropped from both the attribute list and the
    serialized record, so the model has to infer its value instead of copying
    it. *keyword* fills the template's ``{keyword}`` slot.
    """
    remaining = [name for name in record.keys() if name != col]
    message = _DI_TEMPLATE.format(
        keyword=keyword,
        col=col,
        attrs=", ".join(remaining),
        record=_serialize(record, skip=col),
    )
    return wrap_prompt(message)
def parse_ed(text: str) -> bool:
    """Return True when the model's answer says the cell is erroneous.

    The answer is trimmed, lower-cased and stripped of leading ``[`` characters;
    it counts as a "yes" only if what remains begins with ``yes``.
    """
    normalized = text.strip().lower().lstrip("[")
    return normalized[:3] == "yes"
def parse_di(text: str, original: str) -> str:
    """Return the imputed value, or *original* when the answer is unusable.

    The answer is trimmed of surrounding whitespace and double quotes. It is
    rejected (abstain, keep *original*) when it is empty, still contains a
    newline, or is longer than 80 characters.
    """
    candidate = text.strip().strip('"').strip()
    usable = bool(candidate) and "\n" not in candidate and len(candidate) <= 80
    return candidate if usable else original
# ---------------------------------------------------------------- Baran scoring
def score_baran(repaired_dir: str = "eval/results/baran",
                out: str = "eval/results/baran_raha.json") -> dict:
    """Score Baran's repaired CSVs and write the aggregate result as JSON.

    Every ``<name>_seed<k>_repaired.csv`` in *repaired_dir* is scored against the
    (dirty, clean) pair for ``<name>`` under the identical churn-neutral protocol.
    Metrics are macro-averaged over datasets within a seed, then over seeds;
    REAL-F1 is reported as mean ± 95% CI over seeds.

    Args:
        repaired_dir: Directory holding the repaired CSV files.
        out: Path of the JSON file to write; missing parent directories are created.

    Returns:
        The result dict that was written to *out*.

    Raises:
        SystemExit: If no repaired CSV matches the expected file-name pattern.
        ValueError: If the seed part of a file name is not an integer.
    """
    # Kept function-local as in the original — presumably so the prompt helpers
    # can be imported without pandas / the eval package; confirm before hoisting.
    import collections

    import pandas as pd

    from .run_real_multi import _raha_pair, score

    per_seed: dict[int, list] = collections.defaultdict(list)
    per_ds = []
    for path in sorted(Path(repaired_dir).glob("*_seed*_repaired.csv")):
        # "<name>_seed<k>_repaired" -> ("<name>", "<k>"); rsplit keeps any
        # earlier "_seed" occurrences inside the dataset name.
        name, seed = path.stem.rsplit("_repaired", 1)[0].rsplit("_seed", 1)
        seed_id = int(seed)  # parse before scoring so a bad file name fails fast
        repaired = pd.read_csv(path, dtype=str, keep_default_na=False)
        dirty, clean = _raha_pair(name)
        m = score(dirty, clean, repaired)
        per_seed[seed_id].append(m)
        per_ds.append({"name": name, "seed": seed_id, **m})
        print(f" {name:<10} seed{seed}: F1={m['f1']:.3f} P={m['precision']:.3f} "
              f"R={m['recall']:.3f} dmg={m['damage']:.3f}")
    if not per_seed:
        raise SystemExit(f"no repaired CSVs found in {repaired_dir}")

    def mean(xs):
        """Arithmetic mean of an iterable; 0.0 when it is empty."""
        xs = list(xs)
        return sum(xs) / len(xs) if xs else 0.0

    def macro(key):
        """Mean over seeds of the per-seed mean of metric *key* over datasets."""
        return mean(mean(m[key] for m in ms) for ms in per_seed.values())

    seed_f1 = [mean(m["f1"] for m in ms) for ms in per_seed.values()]
    mu = mean(seed_f1)
    # Normal-approximation interval: 1.96 * population std (variance divides by n,
    # not n - 1) / sqrt(n_seeds). With a single seed this yields ci == 0.0.
    var = mean([(x - mu) ** 2 for x in seed_f1])
    ci = 1.96 * (var ** 0.5) / (len(seed_f1) ** 0.5)
    result = {
        "system": "Baran (oracle detection, 20 gold labels)",
        "real_f1": mu, "real_f1_ci": ci, "real_f1_per_seed": seed_f1,
        "damage": macro("damage"),
        "precision": macro("precision"),
        "recall": macro("recall"),
        "n_seeds": len(per_seed), "per_dataset": per_ds,
        "protocol_note": "upper bound: oracle error positions + 20 gold-labeled tuples "
                         "(its package default); damage=0 by construction",
    }
    # Write through a context manager so the handle is flushed and closed
    # deterministically; create the target directory so a long scoring run is
    # not lost to a missing folder.
    out_path = Path(out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as fh:
        json.dump(result, fh, indent=1)
    print(f"\nBaran macro REAL-F1 {mu:.3f} ± {ci:.3f} (n={len(seed_f1)} seeds) -> {out}")
    return result
if __name__ == "__main__":
    # Script entry point: `--score-baran` runs score_baran() with its default
    # paths; without the flag nothing is executed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-baran", action="store_true")
    cli_args = parser.parse_args()
    if cli_args.score_baran:
        score_baran()
|