File size: 5,183 Bytes
16dc556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Run the v4 fine-tune eval on a Modal GPU — fast, unlike local Q8 (250s/call timeouts).

Loads base + the v4 LoRA adapter in bf16 (better fidelity than the GGUF), runs both eval
layers (synthetic matrix + real hospital, batched). Cost-bounded: L4 GPU, ~15 min.

    uv run modal run scripts/modal_eval.py            # default n=20 synthetic
    uv run modal run scripts/modal_eval.py --n 12
"""

import modal

# Paths/globs excluded when the working directory is copied into the image
# (passed as `ignore=` to `add_local_dir` below).
IGNORE = [
    ".venv",
    ".git",
    "data",
    "*.gguf",
    "**/__pycache__",
    ".gstack",
    "frontend/variant_a",
    "frontend/variant_b",
    "frontend/variant_c",
]

# Container image: Debian slim + Python 3.12, the eval's pip dependencies, and a
# copy of the current directory placed at /root/repo.
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "torch",
        "transformers",
        "peft",
        "accelerate",
        "pandas",
        "jsonschema",
        "huggingface_hub",
        "pycountry",
        "sentencepiece",
    )
    .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True)
)

app = modal.App("scrubdata-eval", image=image)

# Named Dict used to persist results, so a DETACHED run survives a dropped
# (cellular) client connection.
results = modal.Dict.from_name(
    "scrubdata-eval-results",
    create_if_missing=True,
)


@app.function(gpu="L4", timeout=1800)
def run_eval(n_synth: int = 20):
    """Run both eval layers for the v4 fine-tune inside the Modal container.

    Loads the base model in bf16 with the LoRA adapter applied, evaluates three
    systems (ORACLE, HEURISTIC, FT_v4) on the first ``n_synth`` gold cases, then
    scores a batched fine-tune plan on the real hospital data against a no-op
    baseline.

    Args:
        n_synth: Number of leading entries of ``load_gold()`` used for layer 1.

    Returns:
        A dict with keys ``"layer1"`` (system name -> result of ``evaluate``),
        ``"layer2_ft"`` and ``"layer2_noop"`` (results of ``_score``).

    Side effects:
        Prints the formatted table and stores ``{"out": ..., "table": ...}``
        under ``"latest"`` in the module-level ``results`` Dict.
    """
    import os
    import sys

    # The repo copy lives at /root/repo; switch there and put it on sys.path
    # before importing the project packages below.
    os.chdir("/root/repo")
    sys.path.insert(0, "/root/repo")

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt
    from scrubdata.profiler import profile_dataframe
    from scrubdata.model_planner import _extract_json, make_batched_planner
    from scrubdata.executor import apply_plan
    from scrubdata.planner import mock_plan
    from eval.run_eval import evaluate
    from eval.gold import load_gold
    from eval.run_real import _ensure_data, _load, _score

    base_id = "unsloth/Qwen3-4B-Instruct-2507"
    adapter_id = "ricalanis/scrubdata-qwen3-4b-v4"

    # Tokenizer comes from the adapter repo; weights are base + adapter.
    tokenizer = AutoTokenizer.from_pretrained(adapter_id)
    backbone = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    model = PeftModel.from_pretrained(backbone, adapter_id).eval()

    # Stop on either the tokenizer's EOS or <|im_end|> when the latter resolves
    # to an id; otherwise fall back to the plain EOS id.
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is None:
        stop_ids = tokenizer.eos_token_id
    else:
        stop_ids = [tokenizer.eos_token_id, im_end_id]

    def base_planner(df, *_):
        """Generate a cleaning plan for ``df``; extra positional args are ignored."""
        profile = profile_dataframe(df)
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_prompt(profile, df)},
        ]
        encoded = tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
        prompt_ids = encoded["input_ids"].to(model.device)
        mask = encoded["attention_mask"].to(model.device)
        with torch.no_grad():
            # stop at <|im_end|> so we don't run to max_new_tokens every call (was the
            # 50s/call slowdown that blew the timeout); attn mask silences the warning.
            generated = model.generate(
                input_ids=prompt_ids,
                attention_mask=mask,
                max_new_tokens=4000,
                do_sample=False,
                eos_token_id=stop_ids,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = prompt_ids.shape[1]
        completion = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
        plan = _extract_json(completion)
        if plan is None:
            return {"__error__": "no_json"}
        # Fill in any top-level sections the model left out.
        for section in ("table_operations", "columns", "flags"):
            plan.setdefault(section, [])
        return plan

    # Layer 1 — synthetic frozen gold
    gold_cases = load_gold()[:n_synth]
    systems = {"ORACLE": lambda df, gp: gp,
               "HEURISTIC": lambda df, gp: mock_plan(df),
               "FT_v4": base_planner}
    layer1 = {}
    for label, planner in systems.items():
        layer1[label] = evaluate(planner, gold_cases)
    report = {"layer1": layer1}

    # Layer 2 — real hospital (batched)
    _ensure_data()
    dirty, clean = _load()
    batched_planner = make_batched_planner(base_planner, batch_size=4)
    ft_plan = batched_planner(dirty)
    cleaned, _ = apply_plan(dirty, ft_plan)
    report["layer2_ft"] = _score(dirty, clean, cleaned)
    # Baseline: score the untouched dirty frame as if it were the cleaned output.
    report["layer2_noop"] = _score(dirty, clean, dirty)

    table = _format(report)
    print(table)  # goes to Modal logs
    # survives client disconnect
    results["latest"] = {"out": report, "table": table}
    return report


def _format(r) -> str:
    lines = ["\n=== Layer 1 (synthetic) ==="]
    cols = ["json_valid", "op_f1", "canon_f1", "canon_r", "recovery"]
    lines.append(f"{'system':<12}" + "".join(f"{c:>11}" for c in cols))
    for name, m in r["layer1"].items():
        lines.append(f"{name:<12}" + "".join(f"{m[c]:>11.3f}" for c in cols))
    lines.append("\n=== Layer 2 (real hospital) ===")
    for k in ("layer2_noop", "layer2_ft"):
        m = r[k]
        lines.append(f"{k:<12} repair_recall={m['repair_recall']:.3f} "
                     f"repair_prec={m['repair_prec']:.3f} recovery={m['recovery']:.3f} "
                     f"fixed={m['_fixed']}/{m['_errors']}")
    return "\n".join(lines)


@app.local_entrypoint()
def main(n: int = 20):
    """Run the eval remotely with ``n`` synthetic cases and print the report.

    Args:
        n: Forwarded to ``run_eval`` as ``n_synth``.
    """
    # Attached mode: wait for the remote function to return, then print locally.
    # (For flaky connections, the function also persists to the `results` Dict,
    # so `--detach` + Dict-fetch still works.)
    eval_output = run_eval.remote(n_synth=n)
    report = _format(eval_output)
    print(report)