File size: 8,686 Bytes
16dc556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""GRPO pilot: RL the 4B planner against OUR EXECUTOR as the verifiable reward.

Hand-rolled GRPO loop (TRL 0.14-0.17 all hard-require vllm at GRPO import in this
stack; the algorithm is ~100 lines and the pilot question is signal, not framework
purity): per step, sample G completions for one episode prompt, reward each by
EXECUTING the plan against the episode's clean slice, normalize advantages within
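the group, take a policy-gradient step on LoRA params. v1 of the pilot ran without
a KL-ref term; v2 anchors the policy to the frozen base with a KL penalty
(`kl_beta`, default 0.05; LoRA r16, default lr 5e-6), using the same model with
the adapters disabled as the reference.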

Reward: invalid JSON -1.0; +0.2 valid JSON; +0.2 schema-valid;
        +2.0 * churn-neutral F1 − 4.0 * damage; execution exception −0.5.
CONTROL ARM (--control): random rewards, identical config (Spurious Rewards check).

    uv run modal run --detach scripts/modal_grpo.py                 # main, 150 steps
    uv run modal run --detach scripts/modal_grpo.py --control --steps 100
"""

import modal

# Glob patterns excluded when the repository is copied into the container image
# (passed as `ignore=` to `add_local_dir` below).
IGNORE = [".venv/**", ".git/**", "*.gguf", "**/__pycache__/**", ".gstack/**",
          "design/**", "frontend/variant_*/**", "notebooks/**", ".pytest_cache/**",
          "data/**", "eval/results/**"]

# Container image: Debian slim + Python 3.11, the training dependencies, and a
# copy of the repository at /root/repo.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install("torch", "transformers>=4.45", "peft", "accelerate",
                 "pandas", "jsonschema", "pycountry", "sentencepiece")
    .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True)
    # IGNORE excludes data/**, so the episode file is added back explicitly.
    .add_local_file("data/grpo_episodes.jsonl", "/root/repo/data/grpo_episodes.jsonl",
                    copy=True)
)
# Modal app whose functions run in the image above.
app = modal.App("scrubdata-grpo", image=image)
# Shared dict that train_grpo writes its run summary into (created on first use).
results = modal.Dict.from_name("scrubdata-train-results", create_if_missing=True)
# Volume mounted at /vol by train_grpo to hold the saved adapter. No
# create_if_missing here — presumably the volume is expected to exist already.
adapter_vol = modal.Volume.from_name("scrubdata-v5-adapter")


@app.function(gpu="A100-80GB", timeout=4 * 3600, volumes={"/vol": adapter_vol})
def train_grpo(steps: int = 150, control: bool = False, seed: int = 0,
               group: int = 6, lr: float = 5e-6, max_new: int = 1024,
               kl_beta: float = 0.05, dest_name: str = ""):
    """Run the hand-rolled GRPO loop and save the resulting LoRA adapter.

    Per step: sample ``group`` completions for one episode prompt, reward each
    one, normalize the rewards within the group into advantages, and take one
    clipped policy-gradient step on the LoRA parameters (optionally with a KL
    penalty against the adapter-free base model).

    Args:
        steps: Number of training steps; episodes are cycled if there are fewer.
        control: If true, rewards are uniform random numbers instead of the
            executor-based reward (the control arm).
        seed: Seed for torch and for the episode shuffle / random rewards.
        group: Completions sampled per prompt. Must be at least 2, because the
            advantage is normalized by the within-group standard deviation.
        lr: AdamW learning rate for the LoRA parameters.
        max_new: Maximum number of generated tokens per completion.
        kl_beta: Weight of the KL penalty; 0 disables the reference pass.
        dest_name: Where to save the adapter. Empty selects
            ``/vol/grpo_control`` or ``/vol/grpo_pilot``. An absolute path is
            used as given; a relative name is placed under ``/vol`` so that it
            lands on the committed volume.

    Returns:
        A summary dict (arm, steps, first/last reward points, mean reward over
        the first/last 25 steps, adapter path). The same dict is stored in
        ``results`` under the final path component of the adapter location.

    Raises:
        ValueError: If ``group`` is below 2 or the episode file has no episodes.
    """
    import io
    import json
    import math
    import os
    import random
    import sys

    import torch
    os.chdir("/root/repo")
    sys.path.insert(0, "/root/repo")
    import pandas as pd
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from eval.metrics import is_valid
    from eval.run_real_multi import _cell_only, score
    from scrubdata.executor import apply_plan
    from scrubdata.model_planner import _extract_json

    # Fail before loading the model: a single-sample group has an undefined
    # (NaN) standard deviation, which would turn every advantage into NaN.
    if group < 2:
        raise ValueError(f"group must be >= 2 for group-normalized advantages, got {group}")

    torch.manual_seed(seed)
    rng = random.Random(seed)
    with open("data/grpo_episodes.jsonl", encoding="utf-8") as fh:
        # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
        episodes = [json.loads(line) for line in fh if line.strip()]
    if not episodes:
        raise ValueError("data/grpo_episodes.jsonl contains no episodes")
    rng.shuffle(episodes)

    base_id = "unsloth/Qwen3-4B-Instruct-2507"
    tok = AutoTokenizer.from_pretrained(base_id)
    model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16,
                                                 device_map="cuda")
    model = get_peft_model(model, LoraConfig(
        r=16, lora_alpha=32, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]))
    model.train()
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr)
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    eos_ids = [tok.eos_token_id] + ([im_end] if im_end is not None else [])

    def reward(comp: str, ep) -> float:
        """Score one completion: random in the control arm, else executor-based.

        -1.0 when no JSON object can be extracted; otherwise 0.2, plus 0.2 if
        the plan passes ``is_valid``, plus ``2.0 * f1 - 4.0 * damage`` from
        executing the plan on the episode's dirty CSV and scoring it against
        the clean CSV, or minus 0.5 if execution/scoring raises.
        """
        if control:
            return rng.random()
        plan = _extract_json(comp)
        # A JSON array or scalar is not a usable plan; treat it like unparseable
        # output instead of crashing on `.setdefault` below.
        if not isinstance(plan, dict):
            return -1.0
        r = 0.2
        plan.setdefault("table_operations", [])
        plan.setdefault("columns", [])
        plan.setdefault("flags", [])
        if is_valid(plan):
            r += 0.2
        try:
            dirty = pd.read_csv(io.StringIO(ep["dirty_csv"]), dtype=str,
                                keep_default_na=False)
            clean = pd.read_csv(io.StringIO(ep["clean_csv"]), dtype=str,
                                keep_default_na=False)
            cleaned, _ = apply_plan(dirty, _cell_only(plan))
            m = score(dirty, clean, cleaned)
            r += 2.0 * m["f1"] - 4.0 * m["damage"]
        except Exception:  # noqa: BLE001
            # Deliberate best-effort: a plan that breaks the executor is a bad
            # plan, so it is penalized rather than aborting the run.
            r -= 0.5
        return r

    def mean_reward(points):
        """Mean of the reward component of (step, reward) points, or None if empty."""
        if not points:
            return None
        return round(sum(r for _, r in points) / len(points), 3)

    curve = []
    for step in range(steps):
        ep = episodes[step % len(episodes)]
        prompt = tok.apply_chat_template(ep["messages"], tokenize=False,
                                         add_generation_prompt=True)
        # NOTE(review): with the tokenizer's default truncation side an over-long
        # prompt would lose its tail (including the generation prompt) — confirm
        # that all episodes fit in 2304 tokens.
        enc = tok(prompt, return_tensors="pt", truncation=True, max_length=2304)
        ids = enc["input_ids"].cuda()
        attn = enc["attention_mask"].cuda()
        with torch.no_grad():
            gen = model.generate(input_ids=ids.repeat(group, 1),
                                 attention_mask=attn.repeat(group, 1),
                                 do_sample=True, temperature=0.9, top_p=0.95,
                                 max_new_tokens=max_new, eos_token_id=eos_ids,
                                 pad_token_id=tok.eos_token_id,
                                 # presumably the tokenizer's tool-call tag ids —
                                 # TODO confirm against the vocabulary
                                 suppress_tokens=[151657, 151658])
        plen = ids.shape[1]
        comps = [tok.decode(g[plen:], skip_special_tokens=True) for g in gen]
        rs = torch.tensor([reward(c, ep) for c in comps], dtype=torch.float32)
        mean_r = rs.mean().item()
        curve.append((step, round(mean_r, 3)))
        std = rs.std()
        spread = float(std)
        # Degenerate group (all rewards equal) carries no signal; a non-finite
        # spread (NaN/inf reward) would poison the weights, so skip that too.
        if not math.isfinite(spread) or spread < 1e-5:
            continue
        adv = (rs - rs.mean()) / (std + 1e-6)
        # teacher-forced logprobs of sampled completions under the current policy
        opt.zero_grad()
        loss_total = 0.0
        for g_seq, a in zip(gen, adv.tolist()):
            if abs(a) < 1e-6:
                continue
            seq = g_seq.unsqueeze(0)
            # completion-token labels: mask the prompt and everything after the
            # first eos in the completion region
            labels = seq.clone()
            labels[:, :plen] = -100
            comp_region = seq[:, plen:]
            eos_pos = (comp_region == tok.eos_token_id) | \
                      ((comp_region == im_end) if im_end is not None else
                       torch.zeros_like(comp_region, dtype=torch.bool))
            # The first eos keeps its label (running count == 1) so stopping is
            # trained; padding uses the eos id, so every later position has a
            # count >= 2 and is masked.
            after_first_eos = eos_pos.float().cumsum(dim=1) > 1
            labels[:, plen:][after_first_eos] = -100
            out = model(input_ids=seq)
            # Shift by one: logits at position t predict the token at t + 1.
            logits = out.logits[:, :-1]
            tgt = labels[:, 1:]
            mask = tgt != -100
            lp = torch.log_softmax(logits.float(), dim=-1)
            # clamp(min=0) only makes the masked (-100) indices gatherable; their
            # values are zeroed by `mask` below.
            tok_lp = lp.gather(-1, tgt.clamp(min=0).unsqueeze(-1)).squeeze(-1)
            mean_lp = (tok_lp * mask).sum() / mask.sum().clamp(min=1)
            # KL anchor to the frozen base (v2 fix: v1 ran unanchored and BOTH
            # arms destroyed JSON discipline — pure RL drift, caught by the
            # random-reward control). With LoRA the ref is free: disable adapters.
            kl = torch.tensor(0.0, device=seq.device)
            if kl_beta > 0:
                with torch.no_grad(), model.disable_adapter():
                    ref_logits = model(input_ids=seq).logits[:, :-1]
                    ref_lp_tok = torch.log_softmax(ref_logits.float(), dim=-1).gather(
                        -1, tgt.clamp(min=0).unsqueeze(-1)).squeeze(-1)
                kl = ((tok_lp - ref_lp_tok) * mask).sum() / mask.sum().clamp(min=1)
            # Dividing by `group` averages the per-sample losses across the
            # gradient accumulation over the group.
            loss = (-(a * mean_lp) + kl_beta * kl.abs()) / group
            loss.backward()
            loss_total += float(loss)
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        opt.step()
        if step % 5 == 0:
            recent = [r for _, r in curve[-10:]]
            print(f"step {step}: reward {mean_r:.3f} (avg-10 "
                  f"{sum(recent)/len(recent):.3f}) loss {loss_total:.4f}", flush=True)

    # Resolve the save location. A bare name must land on the mounted volume;
    # otherwise it would be written under the working directory and lost when
    # the container exits, since only /vol is committed.
    if not dest_name:
        dest = "/vol/grpo_control" if control else "/vol/grpo_pilot"
    elif os.path.isabs(dest_name):
        dest = dest_name
    else:
        dest = os.path.join("/vol", dest_name)
    model.save_pretrained(dest)
    adapter_vol.commit()
    summary = {"arm": "control" if control else "main", "steps": steps,
               "reward_first10": curve[:10], "reward_last10": curve[-10:],
               # None (instead of a ZeroDivisionError) when no step was run.
               "reward_mean_first25": mean_reward(curve[:25]),
               "reward_mean_last25": mean_reward(curve[-25:]),
               "adapter": dest}
    # normpath drops a trailing slash so the results key is never empty.
    key = os.path.basename(os.path.normpath(dest))
    results[key] = summary
    print("GRPO DONE:", summary)
    return summary


@app.local_entrypoint()
def main(steps: int = 150, control: bool = False, dest_name: str = ""):
    """Spawn ``train_grpo`` without waiting for its result and print the call id.

    ``steps``, ``control`` and ``dest_name`` are forwarded unchanged; every
    other training parameter keeps the default declared on ``train_grpo``.
    """
    launch_kwargs = {"steps": steps, "control": control, "dest_name": dest_name}
    handle = train_grpo.spawn(**launch_kwargs)
    print(f"Launched detached. call_id={handle.object_id}")