File size: 7,476 Bytes
16dc556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Fast standalone eval of the v5 adapter saved in the Modal Volume.

The in-training eval was slow because it ran the 4-bit (bitsandbytes) training model.
Here we load the base in BF16 + the LoRA adapter from the volume and merge -> fast
generation. Reports synthetic gold + the real Raha hospital repair_recall (the headline).

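    uv run modal run --detach scripts/modal_eval_v5.py

(`main` only spawns the remote call, so `--detach` is needed to keep the app running
after it returns. Results are also stored in the `scrubdata-eval-v5-results` Modal Dict.)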
"""

import modal

# Glob patterns excluded when the repo is copied into the image. "data/**" is excluded
# wholesale; the eval-only pieces that are still needed are layered back in below.
IGNORE = [
    ".venv/**", ".git/**", "*.gguf", "**/__pycache__/**", ".gstack/**",
    "design/**", "frontend/variant_*/**", "notebooks/**", ".pytest_cache/**", "data/**",
]

# Harvested EVAL-ONLY pairs as (local dir, in-image dir). Raha sets auto-download
# in-container; these only exist locally via the stage-2/3 harvesters.
_EVAL_PAIR_DIRS = (
    ("data/real/ed2_restaurants", "/root/repo/data/real/ed2_restaurants"),
    ("data/real/tt_co23z7go", "/root/repo/data/real/tt_co23z7go"),
    ("data/real/tt_uma1dnf6", "/root/repo/data/real/tt_uma1dnf6"),
    ("data/real/zeroed_billionaire", "/root/repo/data/real/zeroed_billionaire"),
)

# Entity-reference vocabularies as (local file, in-image file); they are placed under
# data/ because reconcile.default_index loads from there.
_REFERENCE_VOCAB_FILES = (
    ("training/harvests/toughtables_ref.jsonl", "/root/repo/data/toughtables_ref.jsonl"),
    ("training/harvests/musicbrainz_hint_aliases.jsonl",
     "/root/repo/data/musicbrainz_hint_aliases.jsonl"),
    ("training/harvests/wikidata_company_aliases.jsonl",
     "/root/repo/data/wikidata_company_aliases.jsonl"),
    ("training/harvests/ror_aliases.jsonl", "/root/repo/data/ror_aliases.jsonl"),
)

image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install("torch", "transformers>=4.45", "peft", "accelerate",
                 "pandas", "jsonschema", "pycountry", "sentencepiece")
    .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True)
)
# Each add_local_* returns a new Image, so layers are appended in table order.
for _src, _dst in _EVAL_PAIR_DIRS:
    image = image.add_local_dir(_src, _dst, copy=True)
for _src, _dst in _REFERENCE_VOCAB_FILES:
    image = image.add_local_file(_src, _dst, copy=True)

app = modal.App("scrubdata-eval-v5", image=image)
# Volume holding the trained adapter(s); run_eval mounts it at /vol.
adapter_vol = modal.Volume.from_name("scrubdata-v5-adapter")
# Persistent store for eval outputs (created if it does not exist yet).
results = modal.Dict.from_name("scrubdata-eval-v5-results", create_if_missing=True)


@app.function(gpu="A100-80GB", timeout=7200, volumes={"/vol": adapter_vol})
def run_eval(n_synth: int = 20, adapter: str = "/vol/v5", skip_real: bool = False,
             pair_profiles: bool = False, capture: str = ""):
    """Evaluate a LoRA adapter merged into the BF16 base model, on the GPU worker.

    Args:
        n_synth: number of synthetic gold cases (``load_gold()[:n_synth]``) scored for
            both the heuristic baseline and the fine-tuned planner.
        adapter: path of the PEFT adapter inside the mounted volume.
        skip_real: skip the real Raha hospital evaluation.
        pair_profiles: add candidate column pairs to the prompt and constrain the
            returned plan to them; also suffixes the results key with ``_pairs``.
        capture: comma-separated dataset names (loaded via
            ``eval.run_real_multi._fetch``) whose raw model plans are captured.
            Whitespace around names and empty entries are ignored.

    Returns:
        Dict with ``"layer1"`` metrics, plus ``"hospital_ft"``/``"hospital_noop"``/
        ``"hospital_plan"`` unless ``skip_real``, and ``"plans"`` when ``capture``
        names at least one dataset. The dict and its rendered table are also stored
        in the ``results`` Modal Dict.
    """
    import os
    import sys

    import torch

    os.chdir("/root/repo")
    sys.path.insert(0, "/root/repo")  # make the copied repo's packages importable
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    from scrubdata.prompt import SYSTEM_PROMPT, build_user_prompt
    from scrubdata.profiler import profile_dataframe
    from scrubdata.model_planner import _extract_json, make_batched_planner
    from scrubdata.executor import apply_plan
    from scrubdata.planner import mock_plan
    from eval.run_eval import evaluate
    from eval.gold import load_gold
    from eval.run_real import _ensure_data, _load, _score

    base_id = "unsloth/Qwen3-4B-Instruct-2507"
    tok = AutoTokenizer.from_pretrained(base_id)
    base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16,
                                                device_map="cuda")
    # bf16-native merge: fold the LoRA weights into the base so generate() runs a plain model.
    model = PeftModel.from_pretrained(base, adapter).merge_and_unload()
    model.eval()
    model.config.use_cache = True

    # Stop on the tokenizer EOS or the chat-turn terminator (when the vocab has one).
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    eos_ids = [tok.eos_token_id, im_end] if im_end is not None else tok.eos_token_id
    # Block <tool_call>/</tool_call> so greedy decoding cannot loop on them. Resolve
    # the ids by name rather than hard-coding them (151657/151658 in the Qwen3 vocab),
    # so a different base_id never suppresses unrelated tokens.
    tool_call_ids = [
        i for i in tok.convert_tokens_to_ids(["<tool_call>", "</tool_call>"])
        if i is not None and i != tok.unk_token_id
    ]

    def base_planner(df, *_):
        """Plan one DataFrame with the merged model.

        Returns the parsed plan with list sections defaulted, or
        ``{"__error__": "no_json"}`` when no JSON could be extracted.
        """
        pairs = None
        if pair_profiles:
            from scrubdata.pair_profile import pairs_for_df
            pairs = pairs_for_df(df)
        msgs = [{"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user",
                 "content": build_user_prompt(profile_dataframe(df), df,
                                              candidate_pairs=pairs)}]
        enc = tok.apply_chat_template(msgs, add_generation_prompt=True,
                                      return_tensors="pt", return_dict=True)
        ids = enc["input_ids"].to(model.device)
        with torch.no_grad():
            generated = model.generate(
                input_ids=ids, attention_mask=enc["attention_mask"].to(model.device),
                max_new_tokens=2200, do_sample=False, eos_token_id=eos_ids,
                pad_token_id=tok.eos_token_id, use_cache=True,
                suppress_tokens=tool_call_ids or None)
        # Decode only the continuation, dropping the prompt tokens.
        text = tok.decode(generated[0][ids.shape[1]:], skip_special_tokens=True)
        plan = _extract_json(text)
        if plan is None:
            return {"__error__": "no_json"}
        for section in ("table_operations", "columns", "flags"):
            plan.setdefault(section, [])
        if pairs is not None:
            from scrubdata.pair_profile import constrain_plan
            plan = constrain_plan(plan, pairs)
        return plan

    out = {}
    gold = load_gold()[:n_synth]
    systems = {"HEURISTIC": lambda df, gp: mock_plan(df), "FT_v5": base_planner}
    out["layer1"] = {name: evaluate(fn, gold) for name, fn in systems.items()}
    if not skip_real:
        _ensure_data()
        dirty, clean = _load()
        ft_plan = make_batched_planner(base_planner, batch_size=4)(dirty)
        cleaned, _ = apply_plan(dirty, ft_plan)
        out["hospital_ft"] = _score(dirty, clean, cleaned)
        out["hospital_noop"] = _score(dirty, clean, dirty)  # do-nothing baseline
        out["hospital_plan"] = ft_plan          # raw plan for local precision-curve sweeps

    # Tolerate "a, b" and stray commas in the CLI value instead of fetching " b" / "".
    capture_names = [name.strip() for name in capture.split(",") if name.strip()]
    if capture_names:
        # capture raw grounded model plans for arbitrary eval datasets (GEN metric:
        # plans are applied + scored locally with the full union pipeline). Tables
        # are loaded FULL (same loader contract as eval/generalization.py).
        from eval.run_real_multi import _fetch
        out["plans"] = {}
        for name in capture_names:
            dirty, _clean = _fetch(name)
            print(f"capturing plan: {name} ({len(dirty)} rows)", flush=True)
            out["plans"][name] = make_batched_planner(base_planner, batch_size=4)(dirty)

    table = _format(out)
    print(table)
    # Results key: "latest" for the default adapter, otherwise the adapter directory
    # name. Trailing slashes are stripped so "/vol/v6/" maps to "v6" (and "/vol/v5/"
    # to "latest") instead of an empty key.
    adapter_path = adapter.rstrip("/")
    key = "latest" if adapter_path == "/vol/v5" else adapter_path.rsplit("/", 1)[-1]
    if pair_profiles:
        key += "_pairs"
    results[key] = {"out": out, "table": table}
    return out


def _format(r) -> str:
    """Render eval results as a fixed-width plain-text report.

    ``r["layer1"]`` maps a system name to a metric dict holding the keys in
    ``metrics`` below. The "Real hospital" section is appended only when ``r``
    contains ``"hospital_ft"``; it then also reads ``r["hospital_noop"]``.
    """
    metrics = ("json_valid", "op_f1", "canon_f1", "recovery")
    header = f"{'system':<12}" + "".join(f"{c:>11}" for c in metrics)
    rows = [
        f"{system:<12}" + "".join(f"{scores[c]:>11.3f}" for c in metrics)
        for system, scores in r["layer1"].items()
    ]
    lines = ["\n=== Layer 1 (synthetic) ===", header, *rows]

    if "hospital_ft" in r:
        lines.append("\n=== Real hospital ===")
        for label in ("hospital_noop", "hospital_ft"):
            s = r[label]
            lines.append(f"{label:<13} repair_recall={s['repair_recall']:.3f} "
                         f"repair_prec={s['repair_prec']:.3f} recovery={s['recovery']:.3f}")
    return "\n".join(lines)


@app.local_entrypoint()
def main(adapter: str = "/vol/v5", skip_real: bool = False, n_synth: int = 20,
         pair_profiles: bool = False, capture: str = ""):
    """CLI entrypoint: spawn ``run_eval`` remotely and return without waiting for it.

    Only the call id is printed here. The report is printed by the remote function
    and stored in the ``results`` Modal Dict.
    NOTE(review): an app started by plain ``modal run`` is stopped once this
    entrypoint returns, which may cancel the spawned call. Launch with
    ``modal run --detach`` (confirm against Modal's detached-mode docs).
    """
    options = dict(adapter=adapter, skip_real=skip_real, n_synth=n_synth,
                   pair_profiles=pair_profiles, capture=capture)
    handle = run_eval.spawn(**options)
    print(f"Launched detached. call_id={handle.object_id}")