Spaces:
Running
Running
| """GRPO pilot: RL the 4B planner against OUR EXECUTOR as the verifiable reward. | |
| Hand-rolled GRPO loop (TRL 0.14-0.17 all hard-require vllm at GRPO import in this | |
| stack; the algorithm is ~100 lines and the pilot question is signal, not framework | |
| purity): per step, sample G completions for one episode prompt, reward each by | |
| EXECUTING the plan against the episode's clean slice, normalize advantages within | |
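| the group, take a policy-gradient step on LoRA params. v1 of the pilot ran with no | |
| KL-ref term (relying on LoRA r16 + a small lr to bound drift); v2 adds a KL anchor to | |
| the frozen base (kl_beta, default 0.05; reference pass = adapters disabled). | |
| Default lr is 5e-6. | |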
| Reward: invalid JSON -1.0; +0.2 valid JSON; +0.2 schema-valid; | |
| +2.0 * churn-neutral F1 − 4.0 * damage; execution exception −0.5. | |
| CONTROL ARM (--control): random rewards, identical config (Spurious Rewards check). | |
| uv run modal run --detach scripts/modal_grpo.py # main, 150 steps | |
| uv run modal run --detach scripts/modal_grpo.py --control --steps 100 | |
| """ | |
| import modal | |
# Glob patterns left out when the working tree is copied into the image.
IGNORE = [
    ".venv/**",
    ".git/**",
    "*.gguf",
    "**/__pycache__/**",
    ".gstack/**",
    "design/**",
    "frontend/variant_*/**",
    "notebooks/**",
    ".pytest_cache/**",
    "data/**",
    "eval/results/**",
]

# Container image: Python 3.11 + the training stack, with the repo copied to
# /root/repo.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .pip_install(
        "torch",
        "transformers>=4.45",
        "peft",
        "accelerate",
        "pandas",
        "jsonschema",
        "pycountry",
        "sentencepiece",
    )
    .add_local_dir(".", "/root/repo", ignore=IGNORE, copy=True)
    # "data/**" is ignored above, so the episodes file is added back explicitly.
    .add_local_file(
        "data/grpo_episodes.jsonl",
        "/root/repo/data/grpo_episodes.jsonl",
        copy=True,
    )
)

app = modal.App("scrubdata-grpo", image=image)

# Shared dict that receives one summary entry per finished run.
results = modal.Dict.from_name(
    "scrubdata-train-results",
    create_if_missing=True,
)

# Volume that receives the trained adapters; looked up by name without
# create_if_missing, so it has to exist already.
adapter_vol = modal.Volume.from_name("scrubdata-v5-adapter")
# NOTE(review): main() launches this through `train_grpo.spawn(...)` and the
# default outputs live under /vol, so a Modal function decorator (GPU, volume
# mount) is expected above this definition; it is not visible in this view —
# confirm it is kept when applying this block.
def train_grpo(steps: int = 150, control: bool = False, seed: int = 0,
               group: int = 6, lr: float = 5e-6, max_new: int = 1024,
               kl_beta: float = 0.05, dest_name: str = ""):
    """Run the hand-rolled GRPO loop and save the resulting LoRA adapter.

    Each step takes one episode, samples ``group`` completions for its prompt,
    scores every completion with ``reward`` (executing the plan against the
    episode's CSVs), normalizes the rewards within the group into advantages
    and takes one clipped AdamW step on the LoRA parameters. When
    ``kl_beta > 0`` a penalty on the mean log-prob gap to the adapter-disabled
    model is added to the loss.

    Args:
        steps: Number of GRPO steps; must be >= 1.
        control: If true, rewards are uniform random numbers instead of
            execution scores (random-reward control arm).
        seed: Seed for torch and for the episode shuffle / control rewards.
        group: Completions sampled per step; must be >= 2, because the
            within-group standard deviation is undefined for one sample.
        lr: AdamW learning rate.
        max_new: Maximum number of generated tokens per completion.
        kl_beta: Weight of the KL anchor; ``0`` disables the reference pass.
        dest_name: Where to save the adapter. Empty selects
            ``/vol/grpo_control`` or ``/vol/grpo_pilot``; a relative name is
            placed under ``/vol``; an absolute path is used as given.

    Returns:
        A summary dict (arm, steps, first/last reward points, first/last-25
        reward means, adapter path). The same dict is stored in ``results``
        under the last path component of the adapter destination.

    Raises:
        ValueError: If ``steps < 1``, ``group < 2`` or the episodes file
            contains no episodes.
    """
    # Fail fast, before any heavy import or model download.
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if group < 2:
        # A single sample has NaN std; NaN advantages would slip past the
        # degenerate-group check and poison the adapter weights.
        raise ValueError(f"group must be >= 2, got {group}")

    import io
    import json
    import os
    import random
    import sys

    import torch

    os.chdir("/root/repo")
    sys.path.insert(0, "/root/repo")

    import pandas as pd
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from eval.metrics import is_valid
    from eval.run_real_multi import _cell_only, score
    from scrubdata.executor import apply_plan
    from scrubdata.model_planner import _extract_json

    torch.manual_seed(seed)
    rng = random.Random(seed)

    # Close the file deterministically and tolerate blank lines.
    with open("data/grpo_episodes.jsonl", encoding="utf-8") as fh:
        episodes = [json.loads(line) for line in fh if line.strip()]
    if not episodes:
        raise ValueError("data/grpo_episodes.jsonl contains no episodes")
    rng.shuffle(episodes)

    base_id = "unsloth/Qwen3-4B-Instruct-2507"
    tok = AutoTokenizer.from_pretrained(base_id)
    model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16,
                                                 device_map="cuda")
    model = get_peft_model(model, LoraConfig(
        r=16, lora_alpha=32, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]))
    model.train()

    # The trainable set is fixed after get_peft_model; build it once and reuse
    # it for both the optimizer and gradient clipping.
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr)

    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    eos_ids = [tok.eos_token_id] + ([im_end] if im_end is not None else [])

    def reward(comp: str, ep) -> float:
        """Score one completion against its episode.

        Control arm: a uniform random number. Otherwise: -1.0 when no usable
        plan can be extracted; else 0.2, plus 0.2 if ``is_valid(plan)``, plus
        ``2.0 * f1 - 4.0 * damage`` from ``score``; any exception while
        reading the CSVs, applying the plan or scoring costs 0.5 instead.
        """
        if control:
            return rng.random()
        plan = _extract_json(comp)
        # A non-dict payload (e.g. a top-level JSON list) is not a plan; score
        # it like unparseable output instead of crashing on .setdefault below.
        if not isinstance(plan, dict):
            return -1.0
        r = 0.2
        plan.setdefault("table_operations", [])
        plan.setdefault("columns", [])
        plan.setdefault("flags", [])
        if is_valid(plan):
            r += 0.2
        try:
            dirty = pd.read_csv(io.StringIO(ep["dirty_csv"]), dtype=str,
                                keep_default_na=False)
            clean = pd.read_csv(io.StringIO(ep["clean_csv"]), dtype=str,
                                keep_default_na=False)
            cleaned, _ = apply_plan(dirty, _cell_only(plan))
            m = score(dirty, clean, cleaned)
            r += 2.0 * m["f1"] - 4.0 * m["damage"]
        except Exception:  # noqa: BLE001
            # Deliberate catch-all: a plan that breaks the executor is a
            # penalized outcome, not a reason to stop training.
            r -= 0.5
        return r

    curve = []
    for step in range(steps):
        ep = episodes[step % len(episodes)]
        prompt = tok.apply_chat_template(ep["messages"], tokenize=False,
                                         add_generation_prompt=True)
        enc = tok(prompt, return_tensors="pt", truncation=True, max_length=2304)
        ids = enc["input_ids"].cuda()
        attn = enc["attention_mask"].cuda()
        with torch.no_grad():
            # NOTE(review): 151657/151658 are presumably Qwen's
            # <tool_call>/</tool_call> token ids — confirm against the tokenizer.
            gen = model.generate(input_ids=ids.repeat(group, 1),
                                 attention_mask=attn.repeat(group, 1),
                                 do_sample=True, temperature=0.9, top_p=0.95,
                                 max_new_tokens=max_new, eos_token_id=eos_ids,
                                 pad_token_id=tok.eos_token_id,
                                 suppress_tokens=[151657, 151658])
        plen = ids.shape[1]
        comps = [tok.decode(g[plen:], skip_special_tokens=True) for g in gen]
        rs = torch.tensor([reward(c, ep) for c in comps], dtype=torch.float32)
        mean_r = rs.mean().item()
        curve.append((step, round(mean_r, 3)))

        std_r = rs.std()
        # Degenerate group (all rewards equal): no learning signal, so skip the
        # update but still fall through to the progress log below.
        has_signal = float(std_r) >= 1e-5
        loss_total = 0.0
        if has_signal:
            adv = (rs - rs.mean()) / (std_r + 1e-6)
            # teacher-forced logprobs of sampled completions under the current policy
            opt.zero_grad()
            for g_seq, a in zip(gen, adv.tolist()):
                if abs(a) < 1e-6:
                    continue
                seq = g_seq.unsqueeze(0)
                # completion-token labels: mask the prompt and everything after
                # the first eos in the completion region (the first eos itself
                # stays a target)
                labels = seq.clone()
                labels[:, :plen] = -100
                comp_region = seq[:, plen:]
                eos_pos = comp_region == tok.eos_token_id
                if im_end is not None:
                    eos_pos = eos_pos | (comp_region == im_end)
                after_first_eos = eos_pos.float().cumsum(dim=1) > 1
                labels[:, plen:][after_first_eos] = -100

                logits = model(input_ids=seq).logits[:, :-1]
                tgt = labels[:, 1:]
                mask = tgt != -100
                n_tok = mask.sum().clamp(min=1)
                lp = torch.log_softmax(logits.float(), dim=-1)
                tok_lp = lp.gather(-1, tgt.clamp(min=0).unsqueeze(-1)).squeeze(-1)
                mean_lp = (tok_lp * mask).sum() / n_tok

                # KL anchor to the frozen base (v2 fix: v1 ran unanchored and BOTH
                # arms destroyed JSON discipline — pure RL drift, caught by the
                # random-reward control). With LoRA the ref is free: disable adapters.
                kl = torch.tensor(0.0, device=seq.device)
                if kl_beta > 0:
                    with torch.no_grad(), model.disable_adapter():
                        ref_logits = model(input_ids=seq).logits[:, :-1]
                        ref_lp_tok = torch.log_softmax(ref_logits.float(), dim=-1).gather(
                            -1, tgt.clamp(min=0).unsqueeze(-1)).squeeze(-1)
                    # Computed outside no_grad so the penalty back-propagates
                    # through tok_lp.
                    # NOTE(review): this is |mean log-ratio|, not a per-token
                    # non-negative KL estimator — kept as-is to preserve the
                    # pilot's objective.
                    kl = ((tok_lp - ref_lp_tok) * mask).sum() / n_tok
                # Divide by group so the accumulated gradient is a group mean.
                loss = (-(a * mean_lp) + kl_beta * kl.abs()) / group
                loss.backward()
                loss_total += float(loss)
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()

        if step % 5 == 0:
            recent = [r for _, r in curve[-10:]]
            note = "" if has_signal else " [no update: zero-variance group]"
            print(f"step {step}: reward {mean_r:.3f} (avg-10 "
                  f"{sum(recent)/len(recent):.3f}) loss {loss_total:.4f}{note}",
                  flush=True)

    if dest_name:
        # A bare name would land in the container's working directory, which
        # the volume commit below does not persist; anchor it under /vol like
        # the defaults. Absolute paths are honored unchanged.
        dest = dest_name if dest_name.startswith("/") else f"/vol/{dest_name}"
    else:
        dest = "/vol/grpo_control" if control else "/vol/grpo_pilot"
    model.save_pretrained(dest)
    adapter_vol.commit()

    n25 = min(25, len(curve))  # >= 1 because steps >= 1
    summary = {"arm": "control" if control else "main", "steps": steps,
               "reward_first10": curve[:10], "reward_last10": curve[-10:],
               "reward_mean_first25": round(sum(r for _, r in curve[:n25]) / n25, 3),
               "reward_mean_last25": round(sum(r for _, r in curve[-n25:]) / n25, 3),
               "adapter": dest}
    # Strip a trailing slash first so the results key is never empty.
    key = dest.rstrip("/").rsplit("/", 1)[-1]
    results[key] = summary
    print("GRPO DONE:", summary)
    return summary
def main(steps: int = 150, control: bool = False, dest_name: str = ""):
    """Spawn train_grpo without waiting for it and print the call id.

    Args:
        steps: Number of GRPO steps, forwarded to train_grpo.
        control: Forwarded to train_grpo to select the random-reward arm.
        dest_name: Adapter destination override, forwarded to train_grpo.
    """
    launch_kwargs = {"steps": steps, "control": control, "dest_name": dest_name}
    handle = train_grpo.spawn(**launch_kwargs)
    call_id = handle.object_id
    print(f"Launched detached. call_id={call_id}")