import json

nb = {
 "nbformat": 4,
 "nbformat_minor": 5,
 "metadata": {
   "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
   "language_info": {"name": "python", "version": "3.10.0"},
   "accelerator": "GPU"
 },
 "cells": [
  # ── Cell 0: markdown intro ────────────────────────────────────────────────────
  {
    "cell_type": "markdown", "id": "md0", "metadata": {},
    "source": [
      "# DataCentric-Env β€” GRPO Training Notebook\n\n",
      "**End-to-end runnable by a judge.** Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO via TRL + Unsloth.\n\n",
      "**Runtime required:** GPU (T4 or better). In Colab: `Runtime β†’ Change runtime type β†’ T4 GPU`\n\n",
      "## What this notebook does\n",
      "1. Installs Unsloth, TRL, and dependencies\n",
      "2. Loads Qwen2.5-3B-Instruct in 4-bit with LoRA adapters\n",
      "3. Builds a prompt dataset from live environment observations\n",
      "4. Defines a **reward function** that calls the live environment `/step` endpoint\n",
      "5. Trains with `GRPOTrainer` β€” the reward function IS the environment verifier\n",
      "6. Inspects sampled generations during training to catch reward hacking\n",
      "7. Saves the model using the **Unsloth merge path** (not naive `save_pretrained`)\n\n",
      "> **Grader note:** Format compliance is Grader 1, completely independent of accuracy.\n",
      "> All reward values are clamped to strictly `(0.001, 0.999)` β€” never 0.0 or 1.0."
    ]
  },
  # ── Cell 1: Install ───────────────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 1: Install dependencies\n",
      "# Unsloth install β€” use the correct pip source for Colab\n",
      "import subprocess, sys\n",
      "subprocess.run([sys.executable, '-m', 'pip', 'install',\n",
      "    'unsloth', 'trl>=0.12.0', 'transformers>=4.45.0',\n",
      "    'accelerate', 'peft', 'datasets', 'requests', 'bitsandbytes'\n",
      "], check=True)\n",
      "print('All dependencies installed.')"
    ]
  },
  # ── Cell 2: Imports + config ─────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 2: Imports and configuration\n",
      "import json, re, requests, torch\n",
      "from unsloth import FastLanguageModel\n",
      "from trl import GRPOTrainer, GRPOConfig\n",
      "from datasets import Dataset\n",
      "\n",
      "# Live environment URL β€” deployed on HuggingFace Spaces\n",
      "ENV_URL = 'https://aswini-kumar-datacentric-env.hf.space'\n",
      "\n",
      "# Verify the environment is reachable before training\n",
      "health = requests.get(f'{ENV_URL}/health', timeout=30)\n",
      "assert health.status_code == 200, f'Environment not reachable: {health.status_code}'\n",
      "print(f'Environment health: {health.json()}')\n",
      "\n",
      "SYSTEM_PROMPT = (\n",
      "    'You are a data quality agent. You receive dataset statistics and must choose '\n",
      "    'which specialist tool to call to improve the dataset so a downstream classifier '\n",
      "    'performs better.\\n\\n'\n",
      "    'Always respond with valid JSON in this exact format:\\n'\n",
      "    '{\"agent\": \"<tool_name>\", \"target\": \"<column_or_all>\", \"strategy\": \"<strategy_name>\"}\\n\\n'\n",
      "    'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n",
      "    'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n",
      "    'Balancer strategies: undersample\\n'\n",
      "    'Relabeler: use when labels are noisy, costs 2 budget points.'\n",
      ")\n",
      "\n",
      "def build_prompt(obs: dict) -> str:\n",
      "    return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n",
      "\n",
      "print('Config ready. ENV_URL:', ENV_URL)"
    ]
  },
  # ── Cell 3: Model setup ───────────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 3: Load Qwen2.5-3B-Instruct in 4-bit with LoRA\n",
      "model, tokenizer = FastLanguageModel.from_pretrained(\n",
      "    model_name='unsloth/Qwen2.5-3B-Instruct',\n",
      "    max_seq_length=1024,\n",
      "    load_in_4bit=True,\n",
      "    dtype=None,  # auto\n",
      ")\n",
      "\n",
      "model = FastLanguageModel.get_peft_model(\n",
      "    model,\n",
      "    r=16,\n",
      "    lora_alpha=32,\n",
      "    lora_dropout=0.0,\n",
      "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
      "    use_gradient_checkpointing='unsloth',\n",
      "    random_state=42,\n",
      ")\n",
      "\n",
      "print(f'Model loaded. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
    ]
  },
  # ── Cell 4: Build prompt dataset ─────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 4: Build prompt dataset\n",
      "# Collect N initial observations by calling /reset.\n",
      "# GRPOTrainer will generate multiple completions per prompt and score each via reward_fn.\n",
      "\n",
      "N_PROMPTS = 60  # number of distinct starting states\n",
      "prompts = []\n",
      "\n",
      "difficulties = ['easy'] * 30 + ['medium'] * 20 + ['hard'] * 10  # curriculum mix\n",
      "\n",
      "for i, diff in enumerate(difficulties):\n",
      "    obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n",
      "    prompts.append({\n",
      "        'prompt': build_prompt(obs),\n",
      "        'difficulty': diff,\n",
      "        'initial_obs': json.dumps(obs),\n",
      "    })\n",
      "    if i % 10 == 0:\n",
      "        print(f'  Collected {i+1}/{N_PROMPTS} prompts (difficulty={diff})')\n",
      "\n",
      "dataset = Dataset.from_list(prompts)\n",
      "print(f'Dataset ready: {len(dataset)} prompts')\n",
      "print('Sample prompt (truncated):', dataset[0]['prompt'][:300])"
    ]
  },
  # ── Cell 5: Reward function ───────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 5: Reward function β€” this IS the environment verifier\n",
      "#\n",
      "# TRL GRPOTrainer calls reward_fn(completions, prompts, **kwargs) -> list[float]\n",
      "# For each completion:\n",
      "#   1. Parse the JSON action from the model output\n",
      "#   2. Call /reset to get a fresh env state\n",
      "#   3. Call /step with the parsed action\n",
      "#   4. Return the environment reward (already in (0.001, 0.999))\n",
      "#\n",
      "# Grader 1 (format compliance) fires FIRST, independently of all other graders.\n",
      "# Final reward is always clamped by the server before returning.\n",
      "\n",
      "def parse_action(text: str) -> dict:\n",
      "    \"\"\"Extract JSON action from model output. Returns fallback on parse failure.\"\"\"\n",
      "    text = text.strip()\n",
      "    # Try to extract JSON block\n",
      "    match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n",
      "    if match:\n",
      "        try:\n",
      "            return json.loads(match.group(0))\n",
      "        except Exception:\n",
      "            pass\n",
      "    # Fallback β€” intentionally malformed so Grader 1 fires 0.001\n",
      "    return {'agent': '__invalid__'}\n",
      "\n",
      "\n",
      "def reward_fn(completions, prompts=None, **kwargs):\n",
      "    \"\"\"\n",
      "    Called by GRPOTrainer once per batch.\n",
      "    Returns list[float] of rewards, one per completion.\n",
      "    All values are strictly in (0.001, 0.999) β€” enforced by the server.\n",
      "    \"\"\"\n",
      "    rewards = []\n",
      "    for completion in completions:\n",
      "        try:\n",
      "            action = parse_action(completion)\n",
      "            # Fresh env state for each reward evaluation\n",
      "            requests.post(f'{ENV_URL}/reset', json={}, timeout=20)\n",
      "            result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
      "            reward = float(result.get('reward', 0.001))\n",
      "            # Safety clamp β€” should already be clamped by server\n",
      "            reward = max(0.001, min(0.999, reward))\n",
      "        except Exception as e:\n",
      "            print(f'  [reward_fn error] {e}')\n",
      "            reward = 0.001  # format/connection failure -> minimum\n",
      "        rewards.append(reward)\n",
      "    return rewards\n",
      "\n",
      "\n",
      "# Quick sanity check before training\n",
      "test_completions = [\n",
      "    '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n",
      "    'some random text that is not JSON',\n",
      "    '{\"agent\": \"balancer\", \"strategy\": \"undersample\"}',\n",
      "]\n",
      "test_rewards = reward_fn(test_completions)\n",
      "print('Reward function sanity check:')\n",
      "for c, r in zip(test_completions, test_rewards):\n",
      "    print(f'  action={c[:60]!r:65s} -> reward={r:.4f}')\n",
      "\n",
      "assert all(0.0 < r < 1.0 for r in test_rewards), 'Reward out of valid range!'\n",
      "print('All rewards in (0.0, 1.0): PASS')"
    ]
  },
  # ── Cell 6: Generation inspection helper ─────────────────────────────────────
  {
    "cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 6: Generation inspection β€” reward hacking detection\n",
      "#\n",
      "# Run this BEFORE and AFTER training to verify:\n",
      "#   - Agent produces valid JSON actions\n",
      "#   - Reward rises because accuracy rises, NOT because of exploits\n",
      "#   - Agent varies its tool selection based on observation\n",
      "\n",
      "def inspect_generations(n_episodes=5, label='baseline'):\n",
      "    print(f'\\n=== Generation inspection: {label} ===')\n",
      "    tool_counts = {}\n",
      "    valid_json_count = 0\n",
      "    rewards = []\n",
      "\n",
      "    for ep in range(n_episodes):\n",
      "        obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
      "        prompt = build_prompt(obs)\n",
      "\n",
      "        FastLanguageModel.for_inference(model)\n",
      "        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n",
      "        with torch.no_grad():\n",
      "            out = model.generate(**inputs, max_new_tokens=80, temperature=0.7, do_sample=True)\n",
      "        FastLanguageModel.for_training(model)\n",
      "\n",
      "        response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
      "        action = parse_action(response)\n",
      "        agent = action.get('agent', '__invalid__')\n",
      "\n",
      "        # Step in env\n",
      "        result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
      "        reward = result.get('reward', 0.0)\n",
      "        acc_before = result.get('info', {}).get('prev_accuracy', '?')\n",
      "        acc_after  = result.get('info', {}).get('new_accuracy', '?')\n",
      "\n",
      "        is_valid = isinstance(action.get('agent'), str) and action['agent'] in [\n",
      "            'cleaner', 'augmenter', 'balancer', 'relabeler', 'validator'\n",
      "        ]\n",
      "        if is_valid:\n",
      "            valid_json_count += 1\n",
      "        tool_counts[agent] = tool_counts.get(agent, 0) + 1\n",
      "        rewards.append(reward)\n",
      "\n",
      "        print(f'  Ep{ep+1}: agent={agent:12s} reward={reward:.4f}  '\n",
      "              f'acc {acc_before}->{acc_after}  '\n",
      "              f'raw_output={response[:80]!r}')\n",
      "\n",
      "    print(f'\\n  Valid JSON rate: {valid_json_count}/{n_episodes}')\n",
      "    print(f'  Tool distribution: {tool_counts}')\n",
      "    print(f'  Mean reward: {sum(rewards)/len(rewards):.4f}')\n",
      "    print()\n",
      "\n",
      "    # Reward hacking checks\n",
      "    if len(tool_counts) == 1:\n",
      "        print('  ⚠️  WARNING: Agent always picks same tool β€” possible reward hacking!')\n",
      "    else:\n",
      "        print('  βœ… Tool diversity OK β€” no same-tool spamming detected.')\n",
      "\n",
      "    return {'valid_json_rate': valid_json_count/n_episodes, 'mean_reward': sum(rewards)/len(rewards), 'tool_counts': tool_counts}\n",
      "\n",
      "\n",
      "# Run BEFORE training β€” establishes the baseline\n",
      "before_stats = inspect_generations(n_episodes=5, label='BEFORE training')"
    ]
  },
  # ── Cell 7: GRPO Training ─────────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 7: GRPO Training\n",
      "#\n",
      "# GRPOTrainer:\n",
      "#   - samples num_generations completions per prompt\n",
      "#   - calls reward_fn on all completions\n",
      "#   - updates model to increase probability of higher-reward completions\n",
      "#\n",
      "# The reward_fn IS the environment verifier β€” no learned reward model needed.\n",
      "\n",
      "config = GRPOConfig(\n",
      "    output_dir='./datacentric-grpo',\n",
      "    num_train_epochs=3,\n",
      "    per_device_train_batch_size=1,\n",
      "    gradient_accumulation_steps=4,\n",
      "    num_generations=4,          # 4 completions per prompt, ranked by reward\n",
      "    max_completion_length=100,  # action JSON is short\n",
      "    learning_rate=5e-6,\n",
      "    logging_steps=5,\n",
      "    save_steps=50,\n",
      "    warmup_ratio=0.1,\n",
      "    lr_scheduler_type='cosine',\n",
      "    report_to='none',           # swap to 'wandb' for live curves\n",
      "    remove_unused_columns=False,\n",
      ")\n",
      "\n",
      "trainer = GRPOTrainer(\n",
      "    model=model,\n",
      "    args=config,\n",
      "    reward_funcs=reward_fn,     # environment verifier IS the reward function\n",
      "    train_dataset=dataset,\n",
      "    tokenizer=tokenizer,\n",
      ")\n",
      "\n",
      "print('Starting GRPO training...')\n",
      "print(f'  Dataset: {len(dataset)} prompts')\n",
      "print(f'  Generations per prompt: {config.num_generations}')\n",
      "print(f'  Total reward evaluations per epoch: {len(dataset) * config.num_generations}')\n",
      "print()\n",
      "\n",
      "trainer.train()\n",
      "\n",
      "print('Training complete.')"
    ]
  },
  # ── Cell 8: Post-training inspection ─────────────────────────────────────────
  {
    "cell_type": "code", "id": "c8", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 8: Post-training generation inspection\n",
      "# Compare BEFORE vs AFTER to verify no reward hacking occurred.\n",
      "# Expected: higher valid JSON rate, higher mean reward, varied tool selection.\n",
      "\n",
      "after_stats = inspect_generations(n_episodes=5, label='AFTER training')\n",
      "\n",
      "print('=== Before vs After comparison ===')\n",
      "print(f'  Valid JSON rate: {before_stats[\"valid_json_rate\"]:.1%} -> {after_stats[\"valid_json_rate\"]:.1%}')\n",
      "print(f'  Mean reward:     {before_stats[\"mean_reward\"]:.4f} -> {after_stats[\"mean_reward\"]:.4f}')\n",
      "print(f'  Tool diversity:  {len(before_stats[\"tool_counts\"])} tools -> {len(after_stats[\"tool_counts\"])} tools')\n",
      "\n",
      "if after_stats['mean_reward'] > before_stats['mean_reward']:\n",
      "    print('  βœ… Reward improved after training β€” learning signal working.')\n",
      "else:\n",
      "    print('  ⚠️  Reward did not improve β€” check env connectivity and reward function.')"
    ]
  },
  # ── Cell 9: Save model ────────────────────────────────────────────────────────
  {
    "cell_type": "code", "id": "c9", "metadata": {}, "outputs": [], "execution_count": None,
    "source": [
      "# Cell 9: Save model using Unsloth merge path\n",
      "#\n",
      "# CRITICAL: Do NOT use naive save_pretrained on a 4-bit model.\n",
      "# That upcasts to 16-bit then merges naively, which damages model quality.\n",
      "# Use save_pretrained_merged with save_method='merged_16bit' instead.\n",
      "\n",
      "SAVE_DIR = 'datacentric-grpo-final'\n",
      "\n",
      "model.save_pretrained_merged(\n",
      "    SAVE_DIR,\n",
      "    tokenizer,\n",
      "    save_method='merged_16bit',   # correct Unsloth merge path\n",
      ")\n",
      "\n",
      "print(f'Model saved to {SAVE_DIR}/')\n",
      "print('Test inference immediately to verify model quality before uploading.')\n",
      "\n",
      "# Quick inference test\n",
      "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
      "test_tok = AutoTokenizer.from_pretrained(SAVE_DIR)\n",
      "test_model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype=torch.float16, device_map='auto')\n",
      "\n",
      "obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
      "inp = test_tok(build_prompt(obs), return_tensors='pt').to('cuda')\n",
      "with torch.no_grad():\n",
      "    out = test_model.generate(**inp, max_new_tokens=80)\n",
      "response = test_tok.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n",
      "print(f'Post-save inference test: {response}')\n",
      "action = parse_action(response)\n",
      "print(f'Parsed action: {action}')\n",
      "print('Model save and inference verified.')"
    ]
  },
 ]
}

with open("training/train.ipynb", "w") as f:
    json.dump(nb, f, indent=1)

print("Notebook written successfully.")