# scrubdata/scripts/modal_jellyfish.py
# OpenAI Codex
# deploy: add sponsor:openai tag (Best Use of Codex) + Codex-hardened build
# 16dc556
# Raw
# History Blame Contribute Delete
# 4.84 kB
"""WS4 baseline: Jellyfish-13B (Zhang et al., EMNLP 2024) on the Raha real slice.
Composes its two published cell-level tasks — error detection (yes/no) then data
imputation (infer the flagged cell) — into repairs scored under our churn-neutral
protocol. Prompts verbatim from the model card (built in eval/baselines_learned.py);
recommended decoding (temp 0.35, top_p 0.9, rep-penalty 1.15) via vLLM.
Caveats disclosed in the paper: hospital is in Jellyfish's training data (published ED
F1 95.6); flights/rayyan are in its eval suite; the ED+DI composition is ours.

Usage:
    uv run modal run --detach scripts/modal_jellyfish.py --datasets hospital  # sanity
    uv run modal run --detach scripts/modal_jellyfish.py  # full
"""
import modal
# Glob patterns left out when the repository is copied into the image
# (virtualenvs, VCS data, model weights, caches, frontend variants, notebooks,
# datasets and previous evaluation results).
IGNORE = [
    ".venv/**",
    ".git/**",
    "*.gguf",
    "**/__pycache__/**",
    ".gstack/**",
    "design/**",
    "frontend/variant_*/**",
    "notebooks/**",
    ".pytest_cache/**",
    "data/**",
    "eval/results/**",
]

# Container image: Debian slim + Python 3.11, pinned vLLM, and a copy of the
# repository at /root/repo (minus the IGNORE patterns).
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install("vllm==0.11.2", "pandas", "huggingface_hub")
    .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True)
)

app = modal.App("scrubdata-jellyfish", image=image)

# Named volume mounted at /hf by run_jellyfish and used as the model cache.
hf_cache = modal.Volume.from_name(
    "scrubdata-hf-cache",
    create_if_missing=True,
)

# Named dict that receives per-dataset results and the final summary.
results = modal.Dict.from_name(
    "scrubdata-jellyfish-results",
    create_if_missing=True,
)

# Dataset name -> keyword handed to the data-imputation prompt builder.
KEYWORDS = {
    "hospital": "hospital",
    "beers": "beer",
    "flights": "flight",
    "rayyan": "bibliography",
    "movies_1": "movie",
}
def _split_dataset_names(spec: str) -> list[str]:
    """Return the dataset names in a comma-separated spec, stripped and de-duplicated."""
    stripped = (part.strip() for part in spec.split(","))
    # dict.fromkeys keeps first-seen order while dropping repeats.
    return list(dict.fromkeys(name for name in stripped if name))


def _mean(values) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


@app.function(gpu="A100-80GB", timeout=4 * 3600, volumes={"/hf": hf_cache})
def run_jellyfish(model_id: str = "NECOUDBFM/Jellyfish-13B",
                  datasets: str = "hospital,beers,flights,rayyan,movies_1"):
    """Run the error-detection + imputation baseline and score each dataset.

    For every dataset, each cell is sent through the error-detection prompt;
    cells flagged by ``parse_ed`` are then sent through the data-imputation
    prompt and the parsed answer replaces the cell in a copy of the dirty
    table, which is scored against the clean table with ``score``.

    Args:
        model_id: Model identifier passed to vLLM. Its last path component is
            used as the key prefix in the ``results`` dict.
        datasets: Comma-separated dataset names. Surrounding whitespace, empty
            entries and repeated names are ignored.

    Returns:
        Summary dict with the system name, the mean ``real_f1``, ``damage``,
        ``precision`` and ``recall`` over the datasets, and the per-dataset
        metrics under ``per_dataset``.

    Raises:
        ValueError: If ``datasets`` contains no dataset name.

    Side effects:
        Writes ``"<model>:<dataset>"`` (metrics plus the repaired CSV) and
        ``"<model>:summary"`` entries to the ``results`` dict.
    """
    import os
    import sys

    # Validate the request before the (expensive) model load.
    names = _split_dataset_names(datasets)
    if not names:
        raise ValueError(f"no dataset names found in {datasets!r}")

    os.chdir("/root/repo")
    sys.path.insert(0, "/root/repo")
    os.environ["HF_HOME"] = "/hf"

    from vllm import LLM, SamplingParams
    from eval.baselines_learned import di_prompt, ed_prompt, parse_di, parse_ed
    from eval.run_real_multi import _raha_pair, score

    llm = LLM(model=model_id, dtype="bfloat16", download_dir="/hf")
    # card-recommended decoding; stop must stay "### Instruction:"
    ed_params = SamplingParams(temperature=0.35, top_p=0.9, repetition_penalty=1.15,
                               max_tokens=6, stop=["### Instruction:"])
    di_params = SamplingParams(temperature=0.35, top_p=0.9, repetition_penalty=1.15,
                               max_tokens=64, stop=["### Instruction:"])

    def generate(prompts, params):
        # An empty batch has nothing to produce; skip the engine call.
        return llm.generate(prompts, params) if prompts else []

    key_prefix = model_id.rsplit("/", 1)[-1]
    out = {}
    for name in names:
        dirty, clean = _raha_pair(name)
        records = dirty.to_dict(orient="records")
        columns = list(dirty.columns)
        col_pos = {col: j for j, col in enumerate(columns)}
        cells = [(i, col) for i in range(len(records)) for col in columns]
        print(f"{name}: {len(cells)} ED prompts", flush=True)
        ed_out = generate([ed_prompt(records[i], col) for i, col in cells], ed_params)
        flagged = [cell for cell, o in zip(cells, ed_out)
                   if parse_ed(o.outputs[0].text)]
        print(f"{name}: {len(flagged)} flagged -> DI", flush=True)
        kw = KEYWORDS.get(name, "data")
        di_out = generate([di_prompt(records[i], col, kw) for i, col in flagged],
                          di_params)
        repaired = dirty.copy()
        for (i, col), o in zip(flagged, di_out):
            # ``i`` is a row position (it indexes ``records``), so address the
            # cell positionally rather than by index label.
            j = col_pos[col]
            repaired.iat[i, j] = parse_di(o.outputs[0].text, str(dirty.iat[i, j]))
        m = score(dirty, clean, repaired)
        out[name] = {"f1": m["f1"], "precision": m["precision"], "recall": m["recall"],
                     "damage": m["damage"], "n_flagged": len(flagged),
                     "n_cells": len(cells)}
        results[f"{key_prefix}:{name}"] = {
            **out[name], "repaired_csv": repaired.to_csv(index=False)}
        print(f"  {name}: F1={m['f1']:.3f} P={m['precision']:.3f} R={m['recall']:.3f} "
              f"dmg={m['damage']:.3f}", flush=True)

    summary = {"system": f"Jellyfish ED+DI ({model_id})",
               "real_f1": _mean(d["f1"] for d in out.values()),
               "damage": _mean(d["damage"] for d in out.values()),
               "precision": _mean(d["precision"] for d in out.values()),
               "recall": _mean(d["recall"] for d in out.values()),
               "per_dataset": out}
    results[f"{key_prefix}:summary"] = summary
    print("\nJELLYFISH summary:", {k: round(v, 3) for k, v in summary.items()
                                   if isinstance(v, float)})
    return summary
@app.local_entrypoint()
def main(model_id: str = "NECOUDBFM/Jellyfish-13B",
         datasets: str = "hospital,beers,flights,rayyan,movies_1"):
    """Spawn ``run_jellyfish`` with the given arguments and print the call id.

    Args:
        model_id: Forwarded unchanged to ``run_jellyfish``.
        datasets: Forwarded unchanged to ``run_jellyfish``.
    """
    # spawn() hands back a handle instead of waiting for the result here.
    handle = run_jellyfish.spawn(model_id=model_id, datasets=datasets)
    launch_message = f"Launched detached. call_id={handle.object_id}"
    print(launch_message)