import json
nb = {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python", "version": "3.10.0"},
"accelerator": "GPU"
},
"cells": [
# ── Cell 0: markdown intro ────────────────────────────────────────────────────
{
"cell_type": "markdown", "id": "md0", "metadata": {},
"source": [
"# DataCentric-Env β GRPO Training Notebook\n\n",
"**End-to-end runnable by a judge.** Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO via TRL + Unsloth.\n\n",
"**Runtime required:** GPU (T4 or better). In Colab: `Runtime β Change runtime type β T4 GPU`\n\n",
"## What this notebook does\n",
"1. Installs Unsloth, TRL, and dependencies\n",
"2. Loads Qwen2.5-3B-Instruct in 4-bit with LoRA adapters\n",
"3. Builds a prompt dataset from live environment observations\n",
"4. Defines a **reward function** that calls the live environment `/step` endpoint\n",
"5. Trains with `GRPOTrainer` β the reward function IS the environment verifier\n",
"6. Inspects sampled generations during training to catch reward hacking\n",
"7. Saves the model using the **Unsloth merge path** (not naive `save_pretrained`)\n\n",
"> **Grader note:** Format compliance is Grader 1, completely independent of accuracy.\n",
"> All reward values are clamped to strictly `(0.001, 0.999)` β never 0.0 or 1.0."
]
},
# ── Cell 1: Install ───────────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 1: Install dependencies\n",
"# Unsloth install β use the correct pip source for Colab\n",
"import subprocess, sys\n",
"subprocess.run([sys.executable, '-m', 'pip', 'install',\n",
" 'unsloth', 'trl>=0.12.0', 'transformers>=4.45.0',\n",
" 'accelerate', 'peft', 'datasets', 'requests', 'bitsandbytes'\n",
"], check=True)\n",
"print('All dependencies installed.')"
]
},
# ── Cell 2: Imports + config ──────────────────────────────────────────────────
{
"cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 2: Imports and configuration\n",
"import json, re, requests, torch\n",
"from unsloth import FastLanguageModel\n",
"from trl import GRPOTrainer, GRPOConfig\n",
"from datasets import Dataset\n",
"\n",
"# Live environment URL β deployed on HuggingFace Spaces\n",
"ENV_URL = 'https://aswini-kumar-datacentric-env.hf.space'\n",
"\n",
"# Verify the environment is reachable before training\n",
"health = requests.get(f'{ENV_URL}/health', timeout=30)\n",
"assert health.status_code == 200, f'Environment not reachable: {health.status_code}'\n",
"print(f'Environment health: {health.json()}')\n",
"\n",
"SYSTEM_PROMPT = (\n",
" 'You are a data quality agent. You receive dataset statistics and must choose '\n",
" 'which specialist tool to call to improve the dataset so a downstream classifier '\n",
" 'performs better.\\n\\n'\n",
" 'Always respond with valid JSON in this exact format:\\n'\n",
" '{\"agent\": \"<tool_name>\", \"target\": \"<column_or_all>\", \"strategy\": \"<strategy_name>\"}\\n\\n'\n",
" 'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n",
" 'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n",
" 'Balancer strategies: undersample\\n'\n",
" 'Relabeler: use when labels are noisy, costs 2 budget points.'\n",
")\n",
"\n",
"def build_prompt(obs: dict) -> str:\n",
" return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n",
"\n",
"print('Config ready. ENV_URL:', ENV_URL)"
]
},
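# ── Cell 2b: optional API probe (added sketch) ────────────────────────────────
# Hedged addition, not part of the original pipeline: inspects one /reset
# observation and one /step response so the prompt and reward plumbing can be
# eyeballed. It reuses only the endpoints and a tool/strategy documented above;
# the response schema beyond the 'reward' and 'info' fields is an assumption.
{
"cell_type": "code", "id": "c2b", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 2b (optional): probe the environment API before training\n",
"probe_obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': 'easy'}, timeout=30).json()\n",
"print('Observation keys:', sorted(probe_obs) if isinstance(probe_obs, dict) else type(probe_obs))\n",
"demo_action = {'agent': 'cleaner', 'target': 'all', 'strategy': 'median_impute'}\n",
"probe_step = requests.post(f'{ENV_URL}/step', json=demo_action, timeout=30).json()\n",
"print('Step response (truncated):', json.dumps(probe_step, indent=2)[:500])"
]
},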
# ── Cell 3: Model setup ───────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 3: Load Qwen2.5-3B-Instruct in 4-bit with LoRA\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name='unsloth/Qwen2.5-3B-Instruct',\n",
" max_seq_length=1024,\n",
" load_in_4bit=True,\n",
" dtype=None, # auto\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r=16,\n",
" lora_alpha=32,\n",
" lora_dropout=0.0,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
" use_gradient_checkpointing='unsloth',\n",
" random_state=42,\n",
")\n",
"\n",
"print(f'Model loaded. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
]
},
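# ── Cell 3b: optional prompt-length check (added sketch) ──────────────────────
# Hedged addition: the prompt plus the generated completion must fit within
# max_seq_length=1024, so this optional cell measures one representative live
# prompt. Uses only the /reset call and the build_prompt helper defined above.
{
"cell_type": "code", "id": "c3b", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 3b (optional): check prompt token length against max_seq_length\n",
"sample_obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
"n_tokens = len(tokenizer(build_prompt(sample_obs))['input_ids'])\n",
"print(f'Sample prompt: {n_tokens} tokens (max_seq_length=1024, incl. completion)')"
]
},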
# ── Cell 4: Build prompt dataset ──────────────────────────────────────────────
{
"cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 4: Build prompt dataset\n",
"# Collect N initial observations by calling /reset.\n",
"# GRPOTrainer will generate multiple completions per prompt and score each via reward_fn.\n",
"\n",
"N_PROMPTS = 60 # number of distinct starting states\n",
"prompts = []\n",
"\n",
"difficulties = ['easy'] * 30 + ['medium'] * 20 + ['hard'] * 10 # curriculum mix\n",
"\n",
"for i, diff in enumerate(difficulties):\n",
" obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n",
" prompts.append({\n",
" 'prompt': build_prompt(obs),\n",
" 'difficulty': diff,\n",
" 'initial_obs': json.dumps(obs),\n",
" })\n",
" if i % 10 == 0:\n",
" print(f' Collected {i+1}/{N_PROMPTS} prompts (difficulty={diff})')\n",
"\n",
"dataset = Dataset.from_list(prompts)\n",
"print(f'Dataset ready: {len(dataset)} prompts')\n",
"print('Sample prompt (truncated):', dataset[0]['prompt'][:300])"
]
},
# ── Cell 5: Reward function ───────────────────────────────────────────────────
{
"cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 5: Reward function β this IS the environment verifier\n",
"#\n",
"# TRL GRPOTrainer calls reward_fn(completions, prompts, **kwargs) -> list[float]\n",
"# For each completion:\n",
"# 1. Parse the JSON action from the model output\n",
"# 2. Call /reset to get a fresh env state\n",
"# 3. Call /step with the parsed action\n",
"# 4. Return the environment reward (already in (0.001, 0.999))\n",
"#\n",
"# Grader 1 (format compliance) fires FIRST, independently of all other graders.\n",
"# Final reward is always clamped by the server before returning.\n",
"\n",
"def parse_action(text: str) -> dict:\n",
" \"\"\"Extract JSON action from model output. Returns fallback on parse failure.\"\"\"\n",
" text = text.strip()\n",
" # Try to extract JSON block\n",
" match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n",
" if match:\n",
" try:\n",
" return json.loads(match.group(0))\n",
" except Exception:\n",
" pass\n",
" # Fallback β intentionally malformed so Grader 1 fires 0.001\n",
" return {'agent': '__invalid__'}\n",
"\n",
"\n",
"def reward_fn(completions, prompts=None, **kwargs):\n",
" \"\"\"\n",
" Called by GRPOTrainer once per batch.\n",
" Returns list[float] of rewards, one per completion.\n",
" All values are strictly in (0.001, 0.999) β enforced by the server.\n",
" \"\"\"\n",
" rewards = []\n",
" for completion in completions:\n",
" try:\n",
" action = parse_action(completion)\n",
" # Fresh env state for each reward evaluation\n",
" requests.post(f'{ENV_URL}/reset', json={}, timeout=20)\n",
" result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
" reward = float(result.get('reward', 0.001))\n",
" # Safety clamp β should already be clamped by server\n",
" reward = max(0.001, min(0.999, reward))\n",
" except Exception as e:\n",
" print(f' [reward_fn error] {e}')\n",
" reward = 0.001 # format/connection failure -> minimum\n",
" rewards.append(reward)\n",
" return rewards\n",
"\n",
"\n",
"# Quick sanity check before training\n",
"test_completions = [\n",
" '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n",
" 'some random text that is not JSON',\n",
" '{\"agent\": \"balancer\", \"strategy\": \"undersample\"}',\n",
"]\n",
"test_rewards = reward_fn(test_completions)\n",
"print('Reward function sanity check:')\n",
"for c, r in zip(test_completions, test_rewards):\n",
" print(f' action={c[:60]!r:65s} -> reward={r:.4f}')\n",
"\n",
"assert all(0.0 < r < 1.0 for r in test_rewards), 'Reward out of valid range!'\n",
"print('All rewards in (0.0, 1.0): PASS')"
]
},
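# ── Cell 5b: optional parse_action edge cases (added sketch) ──────────────────
# Hedged addition: offline checks of the parser only; no environment calls and
# no grading claims. The example strings are illustrative, not from the suite.
{
"cell_type": "code", "id": "c5b", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 5b (optional): parse_action edge cases, fully offline\n",
"cases = {\n",
"    'clean JSON': '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n",
"    'JSON inside prose': 'My action: {\"agent\": \"validator\", \"target\": \"all\", \"strategy\": \"none\"} done.',\n",
"    'no JSON at all': 'I would clean the data first.',\n",
"    'broken JSON': '{agent: cleaner}',\n",
"}\n",
"for name, text in cases.items():\n",
"    print(f'{name:18s} -> {parse_action(text)}')"
]
},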
# ── Cell 6: Generation inspection helper ──────────────────────────────────────
{
"cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 6: Generation inspection β reward hacking detection\n",
"#\n",
"# Run this BEFORE and AFTER training to verify:\n",
"# - Agent produces valid JSON actions\n",
"# - Reward rises because accuracy rises, NOT because of exploits\n",
"# - Agent varies its tool selection based on observation\n",
"\n",
"def inspect_generations(n_episodes=5, label='baseline'):\n",
" print(f'\\n=== Generation inspection: {label} ===')\n",
" tool_counts = {}\n",
" valid_json_count = 0\n",
" rewards = []\n",
"\n",
" for ep in range(n_episodes):\n",
" obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
" prompt = build_prompt(obs)\n",
"\n",
" FastLanguageModel.for_inference(model)\n",
" inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n",
" with torch.no_grad():\n",
" out = model.generate(**inputs, max_new_tokens=80, temperature=0.7, do_sample=True)\n",
" FastLanguageModel.for_training(model)\n",
"\n",
" response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
" action = parse_action(response)\n",
" agent = action.get('agent', '__invalid__')\n",
"\n",
" # Step in env\n",
" result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
" reward = result.get('reward', 0.0)\n",
" acc_before = result.get('info', {}).get('prev_accuracy', '?')\n",
" acc_after = result.get('info', {}).get('new_accuracy', '?')\n",
"\n",
" is_valid = isinstance(action.get('agent'), str) and action['agent'] in [\n",
" 'cleaner', 'augmenter', 'balancer', 'relabeler', 'validator'\n",
" ]\n",
" if is_valid:\n",
" valid_json_count += 1\n",
" tool_counts[agent] = tool_counts.get(agent, 0) + 1\n",
" rewards.append(reward)\n",
"\n",
" print(f' Ep{ep+1}: agent={agent:12s} reward={reward:.4f} '\n",
" f'acc {acc_before}->{acc_after} '\n",
" f'raw_output={response[:80]!r}')\n",
"\n",
" print(f'\\n Valid JSON rate: {valid_json_count}/{n_episodes}')\n",
" print(f' Tool distribution: {tool_counts}')\n",
" print(f' Mean reward: {sum(rewards)/len(rewards):.4f}')\n",
" print()\n",
"\n",
" # Reward hacking checks\n",
" if len(tool_counts) == 1:\n",
" print(' β οΈ WARNING: Agent always picks same tool β possible reward hacking!')\n",
" else:\n",
" print(' β
Tool diversity OK β no same-tool spamming detected.')\n",
"\n",
" return {'valid_json_rate': valid_json_count/n_episodes, 'mean_reward': sum(rewards)/len(rewards), 'tool_counts': tool_counts}\n",
"\n",
"\n",
"# Run BEFORE training β establishes the baseline\n",
"before_stats = inspect_generations(n_episodes=5, label='BEFORE training')"
]
},
# ── Cell 7: GRPO Training ─────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 7: GRPO Training\n",
"#\n",
"# GRPOTrainer:\n",
"# - samples num_generations completions per prompt\n",
"# - calls reward_fn on all completions\n",
"# - updates model to increase probability of higher-reward completions\n",
"#\n",
"# The reward_fn IS the environment verifier β no learned reward model needed.\n",
"\n",
"config = GRPOConfig(\n",
" output_dir='./datacentric-grpo',\n",
" num_train_epochs=3,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" num_generations=4, # 4 completions per prompt, ranked by reward\n",
" max_completion_length=100, # action JSON is short\n",
" learning_rate=5e-6,\n",
" logging_steps=5,\n",
" save_steps=50,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type='cosine',\n",
" report_to='none', # swap to 'wandb' for live curves\n",
" remove_unused_columns=False,\n",
")\n",
"\n",
"trainer = GRPOTrainer(\n",
" model=model,\n",
" args=config,\n",
" reward_funcs=reward_fn, # environment verifier IS the reward function\n",
" train_dataset=dataset,\n",
" tokenizer=tokenizer,\n",
")\n",
"\n",
"print('Starting GRPO training...')\n",
"print(f' Dataset: {len(dataset)} prompts')\n",
"print(f' Generations per prompt: {config.num_generations}')\n",
"print(f' Total reward evaluations per epoch: {len(dataset) * config.num_generations}')\n",
"print()\n",
"\n",
"trainer.train()\n",
"\n",
"print('Training complete.')"
]
},
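# ── Cell 7b: optional reward-trend readout (added sketch) ─────────────────────
# Hedged addition: trainer.state.log_history is a standard transformers Trainer
# attribute; the 'reward' key is the metric name GRPOTrainer is expected to log.
# Adjust the key if your TRL version logs it under a different name.
{
"cell_type": "code", "id": "c7b", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 7b (optional): print the logged reward trend\n",
"history = [h for h in trainer.state.log_history if 'reward' in h]\n",
"stride = max(1, len(history) // 10)  # ~10 evenly spaced samples\n",
"for h in history[::stride]:\n",
"    print(f\"step={h.get('step', '?'):>4}  reward={h['reward']:.4f}\")"
]
},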
# ── Cell 8: Post-training inspection ──────────────────────────────────────────
{
"cell_type": "code", "id": "c8", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 8: Post-training generation inspection\n",
"# Compare BEFORE vs AFTER to verify no reward hacking occurred.\n",
"# Expected: higher valid JSON rate, higher mean reward, varied tool selection.\n",
"\n",
"after_stats = inspect_generations(n_episodes=5, label='AFTER training')\n",
"\n",
"print('=== Before vs After comparison ===')\n",
"print(f' Valid JSON rate: {before_stats[\"valid_json_rate\"]:.1%} -> {after_stats[\"valid_json_rate\"]:.1%}')\n",
"print(f' Mean reward: {before_stats[\"mean_reward\"]:.4f} -> {after_stats[\"mean_reward\"]:.4f}')\n",
"print(f' Tool diversity: {len(before_stats[\"tool_counts\"])} tools -> {len(after_stats[\"tool_counts\"])} tools')\n",
"\n",
"if after_stats['mean_reward'] > before_stats['mean_reward']:\n",
" print(' β
Reward improved after training β learning signal working.')\n",
"else:\n",
" print(' β οΈ Reward did not improve β check env connectivity and reward function.')"
]
},
# ── Cell 9: Save model ────────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c9", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 9: Save model using Unsloth merge path\n",
"#\n",
"# CRITICAL: Do NOT use naive save_pretrained on a 4-bit model.\n",
"# That upcasts to 16-bit then merges naively, which damages model quality.\n",
"# Use save_pretrained_merged with save_method='merged_16bit' instead.\n",
"\n",
"SAVE_DIR = 'datacentric-grpo-final'\n",
"\n",
"model.save_pretrained_merged(\n",
" SAVE_DIR,\n",
" tokenizer,\n",
" save_method='merged_16bit', # correct Unsloth merge path\n",
")\n",
"\n",
"print(f'Model saved to {SAVE_DIR}/')\n",
"print('Test inference immediately to verify model quality before uploading.')\n",
"\n",
"# Quick inference test\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"test_tok = AutoTokenizer.from_pretrained(SAVE_DIR)\n",
"test_model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype=torch.float16, device_map='auto')\n",
"\n",
"obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
"inp = test_tok(build_prompt(obs), return_tensors='pt').to('cuda')\n",
"with torch.no_grad():\n",
" out = test_model.generate(**inp, max_new_tokens=80)\n",
"response = test_tok.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n",
"print(f'Post-save inference test: {response}')\n",
"action = parse_action(response)\n",
"print(f'Parsed action: {action}')\n",
"print('Model save and inference verified.')"
]
},
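# ── Cell 10: optional adapter-only save / Hub push (added sketch) ─────────────
# Hedged addition: save_method='lora' and push_to_hub_merged follow Unsloth's
# documented saving options; the Hub repo name and token are placeholders.
{
"cell_type": "code", "id": "c10", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 10 (optional): save the LoRA adapter only (much smaller artifact)\n",
"model.save_pretrained_merged('datacentric-grpo-lora', tokenizer, save_method='lora')\n",
"print('LoRA adapter saved to datacentric-grpo-lora/')\n",
"# To push the merged model to the Hub instead, set your repo and token:\n",
"# model.push_to_hub_merged('<user>/datacentric-grpo', tokenizer,\n",
"#                          save_method='merged_16bit', token='<hf_token>')"
]
},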
]
}
with open("training/train.ipynb", "w") as f:
json.dump(nb, f, indent=1)
print("Notebook written successfully.")