File size: 6,160 Bytes
16dc556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""WS4 learned-repair baselines: scoring + Jellyfish prompt construction.

Both baselines bypass plan dicts (the executor is column-level by design; learned repair
is per-cell) — they produce repaired DataFrames scored by the SAME churn-neutral
`eval.run_real_multi.score` as every other row of the money table.

* Baran: repaired CSVs come from eval/run_baran.py (pinned env). Score here:
      uv run python -m eval.baselines_learned --score-baran
* Jellyfish: prompts built here (unit-testable without a GPU), executed by
  scripts/modal_jellyfish.py (vLLM on Modal), scored in-run with the same `score`.

Jellyfish has NO repair task — we compose its two published cell-level tasks:
error detection (yes/no per cell) then data imputation (infer the flagged cell with the
attribute removed). Prompt templates are verbatim from the NECOUDBFM/Jellyfish-13B model
card; this composition is OURS, not theirs (disclosed in the paper).
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

# Jellyfish-13B system prompt; `wrap_prompt` places it first in every prompt.
# Verbatim from the model card: rewording it changes the inputs the model sees.
SYSTEM_MESSAGE = ("You are an AI assistant that follows instruction extremely well. "
                  "Help as much as you can.")

# Error-detection (ED) task, whole-record form; filled by `ed_prompt`. Placeholders:
#   {attrs}  - comma-separated attribute names of the record
#   {record} - the full record serialized as "k: v" pairs
#   {col}    - the attribute under verification; {val} - its current value
# The reply is read by `parse_ed` (a leading "yes" means the cell is flagged).
_ED_TEMPLATE = (
    "Your task is to determine if there is an error in the value of a specific "
    "attribute within the whole record provided.\n"
    "The attributes may include {attrs}.\n"
    "Errors may include, but are not limited to, spelling errors, inconsistencies, "
    "or values that don't make sense given the context of the whole record.\n"
    "Record [{record}]\n"
    "Attribute for Verification: [{col}: {val}]\n"
    "Question: Is there an error in the value of {col}? "
    "Choose your answer from: [Yes, No]."
)

# Data-imputation (DI) task; filled by `di_prompt` with {col} REMOVED from {attrs}
# and {record} so the model must infer the value rather than copy it. {keyword} is a
# caller-supplied noun for the record type (presumably the dataset/domain - confirm
# against the caller). The reply is read by `parse_di`.
_DI_TEMPLATE = (
    "You are presented with a {keyword} record that is missing a specific attribute: "
    "{col}.\n"
    "Your task is to deduce or infer the value of {col} using the available "
    "information in the record.\n"
    "You may be provided with fields like {attrs} to help you in the inference.\n"
    "Record: [{record}]\n"
    "Based on the provided record, what would you infer is the value for the missing "
    "attribute {col}?\n"
    "Answer only the value of {col}."
)


def wrap_prompt(user_message: str) -> str:
    """Embed *user_message* in the Jellyfish-13B chat scaffold (verbatim from the model card).

    Sections are separated by blank lines: the system message, the ``### Instruction:``
    header, the task text, and an open ``### Response:`` section that ends with a blank
    line for the model to complete.
    """
    sections = (SYSTEM_MESSAGE, "### Instruction:", user_message, "### Response:", "")
    return "\n\n".join(sections)


def _serialize(record: dict, skip: str | None = None) -> str:
    return ", ".join(f"{k}: {v}" for k, v in record.items() if k != skip)


def ed_prompt(record: dict, col: str) -> str:
    """Build the wrapped error-detection prompt asking whether ``record[col]`` is wrong.

    The model sees the whole record, including *col* itself, plus the attribute list.
    Raises KeyError if *col* is not a key of *record*.
    """
    fields = {
        "attrs": ", ".join(record),
        "record": _serialize(record),
        "col": col,
        "val": record[col],
    }
    return wrap_prompt(_ED_TEMPLATE.format(**fields))


def di_prompt(record: dict, col: str, keyword: str) -> str:
    """Build the wrapped data-imputation prompt for the flagged cell *col*.

    *col* is dropped from both the attribute list and the serialized record, so the
    model must infer the value rather than copy it. *keyword* fills the record-type
    slot in the template's opening sentence.
    """
    remaining_attrs = ", ".join(name for name in record.keys() if name != col)
    body = _DI_TEMPLATE.format(
        keyword=keyword,
        col=col,
        attrs=remaining_attrs,
        record=_serialize(record, skip=col),
    )
    return wrap_prompt(body)


def parse_ed(text: str) -> bool:
    """Interpret an error-detection reply; True means the model flags the cell as erroneous.

    The reply counts as "yes" when, after trimming surrounding whitespace, lower-casing,
    and dropping any leading "[" characters (the template lists the options as
    "[Yes, No]"), it begins with "yes".
    """
    answer = text.strip().lower()
    while answer.startswith("["):
        answer = answer[1:]
    return answer[:3] == "yes"


def parse_di(text: str, original: str, max_len: int = 80) -> str:
    """Extract the imputed value from a data-imputation reply, or abstain.

    Surrounding whitespace and double quotes are removed first. The function abstains
    (returns *original* unchanged) when the cleaned answer is unusable: empty, spanning
    more than one line (a rambling reply), or longer than *max_len* characters
    (implausibly long for a single cell).

    Args:
        text: raw model completion.
        original: the cell's current value, returned on abstention.
        max_len: longest answer accepted. The default of 80 matches the previous
            hard-coded cap; raise it for datasets with long free-text cells.

    Returns:
        The cleaned answer, or *original* when abstaining.
    """
    ans = text.strip().strip('"').strip()
    # splitlines() also catches "\r" and other Unicode line breaks, not just "\n".
    if not ans or len(ans.splitlines()) > 1 or len(ans) > max_len:
        return original
    return ans


# ---------------------------------------------------------------- Baran scoring

def score_baran(repaired_dir: str = "eval/results/baran",
                out: str = "eval/results/baran_raha.json") -> dict:
    """Score every ``<name>_seed<k>_repaired.csv`` in *repaired_dir* against its
    (dirty, clean) pair under the identical churn-neutral protocol.

    Each repaired CSV is read all-string (no NA coercion) and scored with
    `run_real_multi.score`. The per-seed metric is the macro average over the datasets
    present for that seed. The headline is macro REAL-F1 mean ± 95% CI over seeds.
    The summary dict is written as JSON to *out* (parent dirs created) and returned.

    Raises:
        SystemExit: if no repaired CSVs are found in *repaired_dir*.
    """
    import collections

    import pandas as pd

    from .run_real_multi import _raha_pair, score

    per_seed: dict[int, list] = collections.defaultdict(list)
    datasets_by_seed: dict[int, set] = collections.defaultdict(set)
    per_ds = []
    for p in sorted(Path(repaired_dir).glob("*_seed*_repaired.csv")):
        # "<name>_seed<k>_repaired" -> ("<name>", "<k>"); rsplit splits on the LAST
        # "_seed", so a dataset name that itself contains "_seed" stays intact.
        name, seed_str = p.stem.rsplit("_repaired", 1)[0].rsplit("_seed", 1)
        seed = int(seed_str)
        repaired = pd.read_csv(p, dtype=str, keep_default_na=False)
        dirty, clean = _raha_pair(name)
        m = score(dirty, clean, repaired)
        per_seed[seed].append(m)
        datasets_by_seed[seed].add(name)
        per_ds.append({"name": name, "seed": seed, **m})
        print(f"  {name:<10} seed{seed_str}: F1={m['f1']:.3f} P={m['precision']:.3f} "
              f"R={m['recall']:.3f} dmg={m['damage']:.3f}")
    if not per_seed:
        raise SystemExit(f"no repaired CSVs found in {repaired_dir}")

    # A seed missing some dataset's CSV would be macro-averaged over a different subset,
    # making per-seed numbers incomparable; surface that instead of hiding it.
    if len({frozenset(names) for names in datasets_by_seed.values()}) > 1:
        print("  WARNING: seeds cover different dataset sets; per-seed macro "
              "averages are not directly comparable")

    def mean(xs):
        xs = list(xs)
        return sum(xs) / len(xs) if xs else 0.0

    # Numeric seed order, so real_f1_per_seed[i] corresponds to seeds[i]
    # (lexicographic file order would put seed10 before seed2).
    seeds = sorted(per_seed)

    def seed_avg(metric):
        """Mean over seeds of the per-seed macro average of *metric*."""
        return mean(mean(m[metric] for m in per_seed[s]) for s in seeds)

    seed_f1 = [mean(m["f1"] for m in per_seed[s]) for s in seeds]
    mu = mean(seed_f1)
    # NOTE(review): population variance (divides by n) with normal z=1.96. A sample
    # variance / t-quantile would widen the CI for few seeds. Left unchanged so the
    # reported CI does not shift; confirm it matches how the other rows compute theirs.
    var = mean([(x - mu) ** 2 for x in seed_f1])
    ci = 1.96 * (var ** 0.5) / (len(seed_f1) ** 0.5)
    result = {
        "system": "Baran (oracle detection, 20 gold labels)",
        "real_f1": mu, "real_f1_ci": ci, "real_f1_per_seed": seed_f1, "seeds": seeds,
        "damage": seed_avg("damage"),
        "precision": seed_avg("precision"),
        "recall": seed_avg("recall"),
        "n_seeds": len(per_seed), "per_dataset": per_ds,
        "protocol_note": "upper bound: oracle error positions + 20 gold-labeled tuples "
                         "(its package default); damage=0 by construction",
    }
    out_path = Path(out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with out_path.open("w") as fh:
        json.dump(result, fh, indent=1)
    print(f"\nBaran macro REAL-F1 {mu:.3f} ± {ci:.3f} (n={len(seed_f1)} seeds) -> {out}")
    return result


if __name__ == "__main__":
    ap = argparse.ArgumentParser(
        description="WS4 learned-repair baselines: score Baran repaired CSVs.")
    ap.add_argument("--score-baran", action="store_true",
                    help="score Baran repaired CSVs and write the summary JSON")
    ap.add_argument("--repaired-dir", default=None,
                    help="directory of <name>_seed<k>_repaired.csv files "
                         "(default: score_baran's default)")
    ap.add_argument("--out", default=None,
                    help="output JSON path (default: score_baran's default)")
    args = ap.parse_args()
    if args.score_baran:
        # Forward only the options the user set, so score_baran's defaults apply otherwise.
        overrides = {key: value
                     for key, value in (("repaired_dir", args.repaired_dir),
                                        ("out", args.out))
                     if value is not None}
        score_baran(**overrides)
    else:
        # No action requested: show usage instead of exiting silently.
        ap.print_help()