Spaces:
Running
Running
| """WS4 baseline: Jellyfish-13B (Zhang et al., EMNLP 2024) on the Raha real slice. | |
| Composes its two published cell-level tasks — error detection (yes/no) then data | |
| imputation (infer the flagged cell) — into repairs scored under our churn-neutral | |
| protocol. Prompts verbatim from the model card (built in eval/baselines_learned.py); | |
| recommended decoding (temp 0.35, top_p 0.9, rep-penalty 1.15) via vLLM. | |
| Caveats disclosed in the paper: hospital is in Jellyfish's training data (published ED | |
| F1 95.6); flights/rayyan are in its eval suite; the ED+DI composition is ours. | |
| uv run modal run --detach scripts/modal_jellyfish.py --datasets hospital # sanity | |
| uv run modal run --detach scripts/modal_jellyfish.py # full | |
| """ | |
| import modal | |
| IGNORE = [".venv/**", ".git/**", "*.gguf", "**/__pycache__/**", ".gstack/**", | |
| "design/**", "frontend/variant_*/**", "notebooks/**", ".pytest_cache/**", | |
| "data/**", "eval/results/**"] | |
| image = ( | |
| modal.Image.debian_slim(python_version="3.11") | |
| .pip_install("vllm==0.11.2", "pandas", "huggingface_hub") | |
| .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True) | |
| ) | |
| app = modal.App("scrubdata-jellyfish", image=image) | |
| hf_cache = modal.Volume.from_name("scrubdata-hf-cache", create_if_missing=True) | |
| results = modal.Dict.from_name("scrubdata-jellyfish-results", create_if_missing=True) | |
| KEYWORDS = {"hospital": "hospital", "beers": "beer", "flights": "flight", | |
| "rayyan": "bibliography", "movies_1": "movie"} | |
| def run_jellyfish(model_id: str = "NECOUDBFM/Jellyfish-13B", | |
| datasets: str = "hospital,beers,flights,rayyan,movies_1"): | |
| import os | |
| import sys | |
| os.chdir("/root/repo") | |
| sys.path.insert(0, "/root/repo") | |
| os.environ["HF_HOME"] = "/hf" | |
| from vllm import LLM, SamplingParams | |
| from eval.baselines_learned import di_prompt, ed_prompt, parse_di, parse_ed | |
| from eval.run_real_multi import _raha_pair, score | |
| llm = LLM(model=model_id, dtype="bfloat16", download_dir="/hf") | |
| # card-recommended decoding; stop must stay "### Instruction:" | |
| ed_params = SamplingParams(temperature=0.35, top_p=0.9, repetition_penalty=1.15, | |
| max_tokens=6, stop=["### Instruction:"]) | |
| di_params = SamplingParams(temperature=0.35, top_p=0.9, repetition_penalty=1.15, | |
| max_tokens=64, stop=["### Instruction:"]) | |
| out = {} | |
| for name in datasets.split(","): | |
| dirty, clean = _raha_pair(name) | |
| records = dirty.to_dict(orient="records") | |
| cells = [(i, col) for i, rec in enumerate(records) for col in dirty.columns] | |
| print(f"{name}: {len(cells)} ED prompts", flush=True) | |
| ed_out = llm.generate([ed_prompt(records[i], col) for i, col in cells], ed_params) | |
| flagged = [(i, col) for (i, col), o in zip(cells, ed_out) | |
| if parse_ed(o.outputs[0].text)] | |
| print(f"{name}: {len(flagged)} flagged -> DI", flush=True) | |
| kw = KEYWORDS.get(name, "data") | |
| di_out = llm.generate([di_prompt(records[i], col, kw) for i, col in flagged], | |
| di_params) | |
| repaired = dirty.copy() | |
| for (i, col), o in zip(flagged, di_out): | |
| repaired.loc[i, col] = parse_di(o.outputs[0].text, str(dirty.loc[i, col])) | |
| m = score(dirty, clean, repaired) | |
| out[name] = {"f1": m["f1"], "precision": m["precision"], "recall": m["recall"], | |
| "damage": m["damage"], "n_flagged": len(flagged), | |
| "n_cells": len(cells)} | |
| results[f"{model_id.rsplit('/', 1)[-1]}:{name}"] = { | |
| **out[name], "repaired_csv": repaired.to_csv(index=False)} | |
| print(f" {name}: F1={m['f1']:.3f} P={m['precision']:.3f} R={m['recall']:.3f} " | |
| f"dmg={m['damage']:.3f}", flush=True) | |
| def mean(xs): | |
| xs = list(xs) | |
| return sum(xs) / len(xs) if xs else 0.0 | |
| summary = {"system": f"Jellyfish ED+DI ({model_id})", | |
| "real_f1": mean(d["f1"] for d in out.values()), | |
| "damage": mean(d["damage"] for d in out.values()), | |
| "precision": mean(d["precision"] for d in out.values()), | |
| "recall": mean(d["recall"] for d in out.values()), | |
| "per_dataset": out} | |
| results[f"{model_id.rsplit('/', 1)[-1]}:summary"] = summary | |
| print("\nJELLYFISH summary:", {k: round(v, 3) for k, v in summary.items() | |
| if isinstance(v, float)}) | |
| return summary | |
| def main(model_id: str = "NECOUDBFM/Jellyfish-13B", | |
| datasets: str = "hospital,beers,flights,rayyan,movies_1"): | |
| call = run_jellyfish.spawn(model_id=model_id, datasets=datasets) | |
| print(f"Launched detached. call_id={call.object_id}") | |