import json

nb = {
  "nbformat": 4,
  "nbformat_minor": 5,
  "metadata": {
    "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
    "language_info": {"name": "python", "version": "3.10.0"},
    "accelerator": "GPU"
  },
  "cells": [

# ── Cell 0: markdown intro ────────────────────────────────────────────────────
{ "cell_type": "markdown", "id": "md0", "metadata": {}, "source": [
  "# DataCentric-Env — GRPO Training Notebook\n\n",
  "**End-to-end runnable by a judge.** Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO via TRL + Unsloth.\n\n",
  "**Runtime required:** GPU (T4 or better). In Colab: `Runtime → Change runtime type → T4 GPU`\n\n",
  "## What this notebook does\n",
  "1. Installs Unsloth, TRL, and dependencies\n",
  "2. Loads Qwen2.5-3B-Instruct in 4-bit with LoRA adapters\n",
  "3. Builds a prompt dataset from live environment observations\n",
  "4. Defines a **reward function** that calls the live environment `/step` endpoint\n",
  "5. Trains with `GRPOTrainer` — the reward function IS the environment verifier\n",
  "6. Inspects sampled generations during training to catch reward hacking\n",
  "7. Saves the model using the **Unsloth merge path** (not naive `save_pretrained`)\n\n",
  "> **Grader note:** Format compliance is Grader 1, completely independent of accuracy.\n",
  "> All reward values are clamped to `[0.001, 0.999]` — never exactly 0.0 or 1.0."
] },

# ── Cell 1: Install ───────────────────────────────────────────────────────────
{ "cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 1: Install dependencies\n",
  "# Unsloth install: plain pip works in Colab. GRPOTrainer needs trl>=0.14.\n",
  "import subprocess, sys\n",
  "subprocess.run([sys.executable, '-m', 'pip', 'install',\n",
  "                'unsloth', 'trl>=0.14.0', 'transformers>=4.45.0',\n",
  "                'accelerate', 'peft', 'datasets', 'requests', 'bitsandbytes'\n",
  "], check=True)\n",
  "print('All dependencies installed.')"
] },

# ── Cell 2: Imports + config ──────────────────────────────────────────────────
{ "cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 2: Imports and configuration\n",
  "import json, re, requests, torch\n",
  "from unsloth import FastLanguageModel\n",
  "from trl import GRPOTrainer, GRPOConfig\n",
  "from datasets import Dataset\n",
  "\n",
  "# Live environment URL — deployed on HuggingFace Spaces\n",
  "ENV_URL = 'https://aswini-kumar-datacentric-env.hf.space'\n",
  "\n",
  "# Verify the environment is reachable before training\n",
  "health = requests.get(f'{ENV_URL}/health', timeout=30)\n",
  "assert health.status_code == 200, f'Environment not reachable: {health.status_code}'\n",
  "print(f'Environment health: {health.json()}')\n",
  "\n",
  "SYSTEM_PROMPT = (\n",
  "    'You are a data quality agent. You receive dataset statistics and must choose '\n",
  "    'which specialist tool to call to improve the dataset so a downstream classifier '\n",
  "    'performs better.\\n\\n'\n",
  "    'Always respond with valid JSON in this exact format:\\n'\n",
  "    '{\"agent\": \"\", \"target\": \"\", \"strategy\": \"\"}\\n\\n'\n",
  "    'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n",
  "    'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n",
  "    'Balancer strategies: undersample\\n'\n",
  "    'Relabeler: use when labels are noisy, costs 2 budget points.'\n",
  ")\n",
  "\n",
  "def build_prompt(obs: dict) -> str:\n",
  "    return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n",
  "\n",
  "print('Config ready. ENV_URL:', ENV_URL)"
] },
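# ── Cell 2b: action schema illustration ───────────────────────────────────────
# Hedged example cell: shows one well-formed action matching the SYSTEM_PROMPT
# contract above. Field values are drawn from the prompt text and the Cell 5
# sanity checks; the live /step endpoint remains the source of truth.
{ "cell_type": "code", "id": "c2b", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 2b: Illustrative action schema (minimal sketch based on SYSTEM_PROMPT).\n",
  "# Values are examples, not an exhaustive spec; /step is the source of truth.\n",
  "example_action = {\n",
  "    'agent': 'cleaner',           # one of: cleaner, augmenter, balancer, relabeler, validator\n",
  "    'target': 'all',              # assumed target spec, as used in the Cell 5 sanity check\n",
  "    'strategy': 'median_impute',  # must be valid for the chosen agent\n",
  "}\n",
  "print('Example action:', json.dumps(example_action))"
] },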
# ── Cell 3: Model setup ───────────────────────────────────────────────────────
{ "cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 3: Load Qwen2.5-3B-Instruct in 4-bit with LoRA\n",
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
  "    model_name='unsloth/Qwen2.5-3B-Instruct',\n",
  "    max_seq_length=1024,\n",
  "    load_in_4bit=True,\n",
  "    dtype=None,  # auto\n",
  ")\n",
  "\n",
  "model = FastLanguageModel.get_peft_model(\n",
  "    model,\n",
  "    r=16,\n",
  "    lora_alpha=32,\n",
  "    lora_dropout=0.0,\n",
  "    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
  "    use_gradient_checkpointing='unsloth',\n",
  "    random_state=42,\n",
  ")\n",
  "\n",
  "print(f'Model loaded. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
] },

# ── Cell 4: Build prompt dataset ──────────────────────────────────────────────
{ "cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 4: Build prompt dataset\n",
  "# Collect N initial observations by calling /reset.\n",
  "# GRPOTrainer will generate multiple completions per prompt and score each via reward_fn.\n",
  "\n",
  "N_PROMPTS = 60  # number of distinct starting states\n",
  "prompts = []\n",
  "\n",
  "difficulties = ['easy'] * 30 + ['medium'] * 20 + ['hard'] * 10  # curriculum mix\n",
  "\n",
  "for i, diff in enumerate(difficulties):\n",
  "    obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n",
  "    prompts.append({\n",
  "        'prompt': build_prompt(obs),\n",
  "        'difficulty': diff,\n",
  "        'initial_obs': json.dumps(obs),\n",
  "    })\n",
  "    if i % 10 == 0:\n",
  "        print(f'  Collected {i+1}/{N_PROMPTS} prompts (difficulty={diff})')\n",
  "\n",
  "dataset = Dataset.from_list(prompts)\n",
  "print(f'Dataset ready: {len(dataset)} prompts')\n",
  "print('Sample prompt (truncated):', dataset[0]['prompt'][:300])"
] },

# ── Cell 5: Reward function ───────────────────────────────────────────────────
{ "cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 5: Reward function — this IS the environment verifier\n",
  "#\n",
  "# TRL GRPOTrainer calls reward_fn(completions, prompts, **kwargs) -> list[float]\n",
  "# For each completion:\n",
  "#   1. Parse the JSON action from the model output\n",
  "#   2. Call /reset to get a fresh env state\n",
  "#   3. Call /step with the parsed action\n",
  "#   4. Return the environment reward (already clamped to [0.001, 0.999])\n",
  "#\n",
  "# Grader 1 (format compliance) fires FIRST, independently of all other graders.\n",
  "# Final reward is always clamped by the server before returning.\n",
  "\n",
  "def parse_action(text: str) -> dict:\n",
  "    \"\"\"Extract the JSON action from model output; return a fallback on parse failure.\"\"\"\n",
  "    text = text.strip()\n",
  "    # Try to extract the first JSON object in the text\n",
  "    match = re.search(r'\\{[^}]+\\}', text)\n",
  "    if match:\n",
  "        try:\n",
  "            return json.loads(match.group(0))\n",
  "        except Exception:\n",
  "            pass\n",
  "    # Fallback — intentionally malformed so Grader 1 fires 0.001\n",
  "    return {'agent': '__invalid__'}\n",
  "\n",
  "\n",
  "def reward_fn(completions, prompts=None, **kwargs):\n",
  "    \"\"\"\n",
  "    Called by GRPOTrainer once per batch.\n",
  "    Returns list[float] of rewards, one per completion.\n",
  "    All values are in [0.001, 0.999] — enforced by the server.\n",
  "    \"\"\"\n",
  "    rewards = []\n",
  "    for completion in completions:\n",
  "        try:\n",
  "            action = parse_action(completion)\n",
  "            # Fresh env state for each reward evaluation\n",
  "            requests.post(f'{ENV_URL}/reset', json={}, timeout=20)\n",
  "            result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
  "            reward = float(result.get('reward', 0.001))\n",
  "            # Safety clamp — should already be clamped by the server\n",
  "            reward = max(0.001, min(0.999, reward))\n",
  "        except Exception as e:\n",
  "            print(f'  [reward_fn error] {e}')\n",
  "            reward = 0.001  # format/connection failure -> minimum\n",
  "        rewards.append(reward)\n",
  "    return rewards\n",
  "\n",
  "\n",
  "# Quick sanity check before training\n",
  "test_completions = [\n",
  "    '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n",
  "    'some random text that is not JSON',\n",
  "    '{\"agent\": \"balancer\", \"strategy\": \"undersample\"}',\n",
  "]\n",
  "test_rewards = reward_fn(test_completions)\n",
  "print('Reward function sanity check:')\n",
  "for c, r in zip(test_completions, test_rewards):\n",
  "    print(f'  action={c[:60]!r:65s} -> reward={r:.4f}')\n",
  "\n",
  "assert all(0.001 <= r <= 0.999 for r in test_rewards), 'Reward out of valid range!'\n",
  "print('All rewards in [0.001, 0.999]: PASS')"
] },
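# ── Cell 5b: reward-call latency estimate ─────────────────────────────────────
# Hedged example cell: each reward evaluation makes two HTTP calls (/reset and
# /step), so network latency dominates training wall-clock. This probe measures
# one evaluation and extrapolates; the totals assume serial reward calls, and
# the probe action's reward value is irrelevant here.
{ "cell_type": "code", "id": "c5b", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 5b: Estimate reward-evaluation latency (rough sketch, assumes serial calls)\n",
  "import time\n",
  "t0 = time.time()\n",
  "_ = reward_fn(['{\"agent\": \"validator\", \"target\": \"all\", \"strategy\": \"\"}'])\n",
  "dt = time.time() - t0\n",
  "print(f'One reward evaluation took {dt:.2f}s (two HTTP round-trips)')\n",
  "\n",
  "# 60 prompts x 4 generations x 3 epochs = 720 evaluations (see Cells 4 and 7)\n",
  "total_evals = 60 * 4 * 3\n",
  "print(f'Estimated time spent in reward calls: ~{total_evals * dt / 60:.0f} min')"
] },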
# ── Cell 6: Generation inspection helper ──────────────────────────────────────
{ "cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 6: Generation inspection — reward hacking detection\n",
  "#\n",
  "# Run this BEFORE and AFTER training to verify:\n",
  "#   - Agent produces valid JSON actions\n",
  "#   - Reward rises because accuracy rises, NOT because of exploits\n",
  "#   - Agent varies its tool selection based on observation\n",
  "\n",
  "def inspect_generations(n_episodes=5, label='baseline'):\n",
  "    print(f'\\n=== Generation inspection: {label} ===')\n",
  "    tool_counts = {}\n",
  "    valid_json_count = 0\n",
  "    rewards = []\n",
  "\n",
  "    FastLanguageModel.for_inference(model)  # switch modes once, not per episode\n",
  "    for ep in range(n_episodes):\n",
  "        obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
  "        prompt = build_prompt(obs)\n",
  "\n",
  "        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n",
  "        with torch.no_grad():\n",
  "            out = model.generate(**inputs, max_new_tokens=80, temperature=0.7, do_sample=True)\n",
  "\n",
  "        response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
  "        action = parse_action(response)\n",
  "        agent = str(action.get('agent', '__invalid__'))  # coerce so formatting never fails\n",
  "\n",
  "        # Step in env\n",
  "        result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
  "        reward = result.get('reward', 0.0)\n",
  "        acc_before = result.get('info', {}).get('prev_accuracy', '?')\n",
  "        acc_after = result.get('info', {}).get('new_accuracy', '?')\n",
  "\n",
  "        is_valid = agent in [\n",
  "            'cleaner', 'augmenter', 'balancer', 'relabeler', 'validator'\n",
  "        ]\n",
  "        if is_valid:\n",
  "            valid_json_count += 1\n",
  "        tool_counts[agent] = tool_counts.get(agent, 0) + 1\n",
  "        rewards.append(reward)\n",
  "\n",
  "        print(f'  Ep{ep+1}: agent={agent:12s} reward={reward:.4f} '\n",
  "              f'acc {acc_before}->{acc_after} '\n",
  "              f'raw_output={response[:80]!r}')\n",
  "    FastLanguageModel.for_training(model)\n",
  "\n",
  "    print(f'\\n  Valid JSON rate: {valid_json_count}/{n_episodes}')\n",
  "    print(f'  Tool distribution: {tool_counts}')\n",
  "    print(f'  Mean reward: {sum(rewards)/len(rewards):.4f}')\n",
  "    print()\n",
  "\n",
  "    # Reward hacking checks\n",
  "    if len(tool_counts) == 1:\n",
  "        print('  ⚠️ WARNING: Agent always picks same tool — possible reward hacking!')\n",
  "    else:\n",
  "        print('  ✅ Tool diversity OK — no same-tool spamming detected.')\n",
  "\n",
  "    return {'valid_json_rate': valid_json_count / n_episodes,\n",
  "            'mean_reward': sum(rewards) / len(rewards),\n",
  "            'tool_counts': tool_counts}\n",
  "\n",
  "\n",
  "# Run BEFORE training — establishes the baseline\n",
  "before_stats = inspect_generations(n_episodes=5, label='BEFORE training')"
] },
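# ── Cell 6b: in-training reward logging ───────────────────────────────────────
# Hedged example cell: a minimal TrainerCallback sketch so reward trends are
# visible mid-run (intro item 6). The 'reward' log key is an assumption about
# TRL's GRPO metric names; adjust the key if your TRL version logs differently.
{ "cell_type": "code", "id": "c6b", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 6b: Minimal callback sketch for watching rewards during training.\n",
  "# Assumption: TRL's GRPO logs a 'reward' metric; rename the key if needed.\n",
  "from transformers import TrainerCallback\n",
  "\n",
  "class RewardLogCallback(TrainerCallback):\n",
  "    def on_log(self, args, state, control, logs=None, **kwargs):\n",
  "        if logs and 'reward' in logs:\n",
  "            print(f\"  step {state.global_step}: mean reward = {logs['reward']:.4f}\")\n",
  "\n",
  "# Optional: pass callbacks=[RewardLogCallback()] to GRPOTrainer in the next cell.\n",
  "print('RewardLogCallback defined.')"
] },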
# ── Cell 7: GRPO Training ─────────────────────────────────────────────────────
{ "cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 7: GRPO Training\n",
  "#\n",
  "# GRPOTrainer:\n",
  "#   - samples num_generations completions per prompt\n",
  "#   - calls reward_fn on all completions\n",
  "#   - updates model to increase probability of higher-reward completions\n",
  "#\n",
  "# The reward_fn IS the environment verifier — no learned reward model needed.\n",
  "\n",
  "config = GRPOConfig(\n",
  "    output_dir='./datacentric-grpo',\n",
  "    num_train_epochs=3,\n",
  "    per_device_train_batch_size=4,  # TRL requires the global batch size to divide evenly by num_generations\n",
  "    gradient_accumulation_steps=1,  # effective batch stays 4\n",
  "    num_generations=4,  # 4 completions per prompt, ranked by reward\n",
  "    max_completion_length=100,  # action JSON is short\n",
  "    learning_rate=5e-6,\n",
  "    logging_steps=5,\n",
  "    save_steps=50,\n",
  "    warmup_ratio=0.1,\n",
  "    lr_scheduler_type='cosine',\n",
  "    report_to='none',  # swap to 'wandb' for live curves\n",
  "    remove_unused_columns=False,\n",
  ")\n",
  "\n",
  "trainer = GRPOTrainer(\n",
  "    model=model,\n",
  "    args=config,\n",
  "    reward_funcs=reward_fn,  # environment verifier IS the reward function\n",
  "    train_dataset=dataset,\n",
  "    processing_class=tokenizer,  # TRL 0.14+ takes processing_class, not tokenizer\n",
  ")\n",
  "\n",
  "print('Starting GRPO training...')\n",
  "print(f'  Dataset: {len(dataset)} prompts')\n",
  "print(f'  Generations per prompt: {config.num_generations}')\n",
  "print(f'  Total reward evaluations per epoch: {len(dataset) * config.num_generations}')\n",
  "print()\n",
  "\n",
  "trainer.train()\n",
  "\n",
  "print('Training complete.')"
] },

# ── Cell 8: Post-training inspection ──────────────────────────────────────────
{ "cell_type": "code", "id": "c8", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 8: Post-training generation inspection\n",
  "# Compare BEFORE vs AFTER to verify no reward hacking occurred.\n",
  "# Expected: higher valid JSON rate, higher mean reward, varied tool selection.\n",
  "\n",
  "after_stats = inspect_generations(n_episodes=5, label='AFTER training')\n",
  "\n",
  "print('=== Before vs After comparison ===')\n",
  "print(f'  Valid JSON rate: {before_stats[\"valid_json_rate\"]:.1%} -> {after_stats[\"valid_json_rate\"]:.1%}')\n",
  "print(f'  Mean reward:     {before_stats[\"mean_reward\"]:.4f} -> {after_stats[\"mean_reward\"]:.4f}')\n",
  "print(f'  Tool diversity:  {len(before_stats[\"tool_counts\"])} tools -> {len(after_stats[\"tool_counts\"])} tools')\n",
  "\n",
  "if after_stats['mean_reward'] > before_stats['mean_reward']:\n",
  "    print('  ✅ Reward improved after training — learning signal working.')\n",
  "else:\n",
  "    print('  ⚠️ Reward did not improve — check env connectivity and reward function.')"
] },
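# ── Cell 8b: per-difficulty snapshot ──────────────────────────────────────────
# Hedged example cell: one greedy action per episode, grouped by difficulty,
# using the same /reset + /step contract as Cells 4 and 6. Sample sizes are
# small, so treat the numbers as a smoke test rather than a benchmark.
{ "cell_type": "code", "id": "c8b", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 8b: Per-difficulty reward snapshot (greedy decoding, 3 episodes each).\n",
  "# Sketch only: tiny samples; reuses the env endpoints from Cells 4 and 6.\n",
  "FastLanguageModel.for_inference(model)\n",
  "for diff in ['easy', 'medium', 'hard']:\n",
  "    rs = []\n",
  "    for _ in range(3):\n",
  "        obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n",
  "        inputs = tokenizer(build_prompt(obs), return_tensors='pt').to('cuda')\n",
  "        with torch.no_grad():\n",
  "            out = model.generate(**inputs, max_new_tokens=80, do_sample=False)\n",
  "        resp = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
  "        result = requests.post(f'{ENV_URL}/step', json=parse_action(resp), timeout=20).json()\n",
  "        rs.append(result.get('reward', 0.0))\n",
  "    print(f'  {diff:6s}: mean reward {sum(rs)/len(rs):.4f} over {len(rs)} episodes')\n",
  "FastLanguageModel.for_training(model)"
] },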
# ── Cell 9: Save model ────────────────────────────────────────────────────────
{ "cell_type": "code", "id": "c9", "metadata": {}, "outputs": [], "execution_count": None, "source": [
  "# Cell 9: Save model using the Unsloth merge path\n",
  "#\n",
  "# CRITICAL: Do NOT use naive save_pretrained on a 4-bit model.\n",
  "# That upcasts to 16-bit then merges naively, which damages model quality.\n",
  "# Use save_pretrained_merged with save_method='merged_16bit' instead.\n",
  "\n",
  "SAVE_DIR = 'datacentric-grpo-final'\n",
  "\n",
  "model.save_pretrained_merged(\n",
  "    SAVE_DIR,\n",
  "    tokenizer,\n",
  "    save_method='merged_16bit',  # correct Unsloth merge path\n",
  ")\n",
  "\n",
  "print(f'Model saved to {SAVE_DIR}/')\n",
  "print('Test inference immediately to verify model quality before uploading.')\n",
  "\n",
  "# Free the training model first: a second fp16 3B copy plus the training\n",
  "# model may not fit in T4 memory.\n",
  "import gc\n",
  "del trainer, model\n",
  "gc.collect()\n",
  "torch.cuda.empty_cache()\n",
  "\n",
  "# Quick inference test on the merged checkpoint\n",
  "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
  "test_tok = AutoTokenizer.from_pretrained(SAVE_DIR)\n",
  "test_model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype=torch.float16, device_map='auto')\n",
  "\n",
  "obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
  "inp = test_tok(build_prompt(obs), return_tensors='pt').to('cuda')\n",
  "with torch.no_grad():\n",
  "    out = test_model.generate(**inp, max_new_tokens=80)\n",
  "response = test_tok.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n",
  "print(f'Post-save inference test: {response}')\n",
  "action = parse_action(response)\n",
  "print(f'Parsed action: {action}')\n",
  "print('Model save and inference verified.')"
] },

  ]
}

# Write the notebook; create the target directory if it doesn't exist yet.
import os
os.makedirs("training", exist_ok=True)
with open("training/train.ipynb", "w") as f:
    json.dump(nb, f, indent=1)
print("Notebook written successfully.")
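# Optional post-write check (a minimal sketch): confirm the file parses as a
# valid v4 notebook. nbformat ships with Jupyter; the try/except keeps this a
# no-op where it isn't installed.
try:
    import nbformat
    nbformat.validate(nbformat.read("training/train.ipynb", as_version=4))
    print("Notebook validates against the nbformat v4 schema.")
except ImportError:
    print("nbformat not installed; skipping schema validation.")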