Spaces:
Sleeping
Sleeping
| import json | |
| nb = { | |
| "nbformat": 4, | |
| "nbformat_minor": 5, | |
| "metadata": { | |
| "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, | |
| "language_info": {"name": "python", "version": "3.10.0"}, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| # ββ Cell 0: markdown intro ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "markdown", "id": "md0", "metadata": {}, | |
| "source": [ | |
| "# DataCentric-Env β GRPO Training Notebook\n\n", | |
| "**End-to-end runnable by a judge.** Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO via TRL + Unsloth.\n\n", | |
| "**Runtime required:** GPU (T4 or better). In Colab: `Runtime β Change runtime type β T4 GPU`\n\n", | |
| "## What this notebook does\n", | |
| "1. Installs Unsloth, TRL, and dependencies\n", | |
| "2. Loads Qwen2.5-3B-Instruct in 4-bit with LoRA adapters\n", | |
| "3. Builds a prompt dataset from live environment observations\n", | |
| "4. Defines a **reward function** that calls the live environment `/step` endpoint\n", | |
| "5. Trains with `GRPOTrainer` β the reward function IS the environment verifier\n", | |
| "6. Inspects sampled generations during training to catch reward hacking\n", | |
| "7. Saves the model using the **Unsloth merge path** (not naive `save_pretrained`)\n\n", | |
| "> **Grader note:** Format compliance is Grader 1, completely independent of accuracy.\n", | |
| "> All reward values are clamped to strictly `(0.001, 0.999)` β never 0.0 or 1.0." | |
| ] | |
| }, | |
| # ββ Cell 1: Install βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 1: Install dependencies\n", | |
| "# Unsloth install β use the correct pip source for Colab\n", | |
| "import subprocess, sys\n", | |
| "subprocess.run([sys.executable, '-m', 'pip', 'install',\n", | |
| " 'unsloth', 'trl>=0.12.0', 'transformers>=4.45.0',\n", | |
| " 'accelerate', 'peft', 'datasets', 'requests', 'bitsandbytes'\n", | |
| "], check=True)\n", | |
| "print('All dependencies installed.')" | |
| ] | |
| }, | |
| # ββ Cell 2: Imports + config βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 2: Imports and configuration\n", | |
| "import json, re, requests, torch\n", | |
| "from unsloth import FastLanguageModel\n", | |
| "from trl import GRPOTrainer, GRPOConfig\n", | |
| "from datasets import Dataset\n", | |
| "\n", | |
| "# Live environment URL β deployed on HuggingFace Spaces\n", | |
| "ENV_URL = 'https://aswini-kumar-datacentric-env.hf.space'\n", | |
| "\n", | |
| "# Verify the environment is reachable before training\n", | |
| "health = requests.get(f'{ENV_URL}/health', timeout=30)\n", | |
| "assert health.status_code == 200, f'Environment not reachable: {health.status_code}'\n", | |
| "print(f'Environment health: {health.json()}')\n", | |
| "\n", | |
| "SYSTEM_PROMPT = (\n", | |
| " 'You are a data quality agent. You receive dataset statistics and must choose '\n", | |
| " 'which specialist tool to call to improve the dataset so a downstream classifier '\n", | |
| " 'performs better.\\n\\n'\n", | |
| " 'Always respond with valid JSON in this exact format:\\n'\n", | |
| " '{\"agent\": \"<tool_name>\", \"target\": \"<column_or_all>\", \"strategy\": \"<strategy_name>\"}\\n\\n'\n", | |
| " 'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n", | |
| " 'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n", | |
| " 'Balancer strategies: undersample\\n'\n", | |
| " 'Relabeler: use when labels are noisy, costs 2 budget points.'\n", | |
| ")\n", | |
| "\n", | |
| "def build_prompt(obs: dict) -> str:\n", | |
| " return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n", | |
| "\n", | |
| "print('Config ready. ENV_URL:', ENV_URL)" | |
| ] | |
| }, | |
| # ββ Cell 3: Model setup βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 3: Load Qwen2.5-3B-Instruct in 4-bit with LoRA\n", | |
| "model, tokenizer = FastLanguageModel.from_pretrained(\n", | |
| " model_name='unsloth/Qwen2.5-3B-Instruct',\n", | |
| " max_seq_length=1024,\n", | |
| " load_in_4bit=True,\n", | |
| " dtype=None, # auto\n", | |
| ")\n", | |
| "\n", | |
| "model = FastLanguageModel.get_peft_model(\n", | |
| " model,\n", | |
| " r=16,\n", | |
| " lora_alpha=32,\n", | |
| " lora_dropout=0.0,\n", | |
| " target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n", | |
| " use_gradient_checkpointing='unsloth',\n", | |
| " random_state=42,\n", | |
| ")\n", | |
| "\n", | |
| "print(f'Model loaded. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')" | |
| ] | |
| }, | |
| # ββ Cell 4: Build prompt dataset βββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 4: Build prompt dataset\n", | |
| "# Collect N initial observations by calling /reset.\n", | |
| "# GRPOTrainer will generate multiple completions per prompt and score each via reward_fn.\n", | |
| "\n", | |
| "N_PROMPTS = 60 # number of distinct starting states\n", | |
| "prompts = []\n", | |
| "\n", | |
| "difficulties = ['easy'] * 30 + ['medium'] * 20 + ['hard'] * 10 # curriculum mix\n", | |
| "\n", | |
| "for i, diff in enumerate(difficulties):\n", | |
| " obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n", | |
| " prompts.append({\n", | |
| " 'prompt': build_prompt(obs),\n", | |
| " 'difficulty': diff,\n", | |
| " 'initial_obs': json.dumps(obs),\n", | |
| " })\n", | |
| " if i % 10 == 0:\n", | |
| " print(f' Collected {i+1}/{N_PROMPTS} prompts (difficulty={diff})')\n", | |
| "\n", | |
| "dataset = Dataset.from_list(prompts)\n", | |
| "print(f'Dataset ready: {len(dataset)} prompts')\n", | |
| "print('Sample prompt (truncated):', dataset[0]['prompt'][:300])" | |
| ] | |
| }, | |
| # ββ Cell 5: Reward function βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 5: Reward function β this IS the environment verifier\n", | |
| "#\n", | |
| "# TRL GRPOTrainer calls reward_fn(completions, prompts, **kwargs) -> list[float]\n", | |
| "# For each completion:\n", | |
| "# 1. Parse the JSON action from the model output\n", | |
| "# 2. Call /reset to get a fresh env state\n", | |
| "# 3. Call /step with the parsed action\n", | |
| "# 4. Return the environment reward (already in (0.001, 0.999))\n", | |
| "#\n", | |
| "# Grader 1 (format compliance) fires FIRST, independently of all other graders.\n", | |
| "# Final reward is always clamped by the server before returning.\n", | |
| "\n", | |
| "def parse_action(text: str) -> dict:\n", | |
| " \"\"\"Extract JSON action from model output. Returns fallback on parse failure.\"\"\"\n", | |
| " text = text.strip()\n", | |
| " # Try to extract JSON block\n", | |
| " match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n", | |
| " if match:\n", | |
| " try:\n", | |
| " return json.loads(match.group(0))\n", | |
| " except Exception:\n", | |
| " pass\n", | |
| " # Fallback β intentionally malformed so Grader 1 fires 0.001\n", | |
| " return {'agent': '__invalid__'}\n", | |
| "\n", | |
| "\n", | |
| "def reward_fn(completions, prompts=None, **kwargs):\n", | |
| " \"\"\"\n", | |
| " Called by GRPOTrainer once per batch.\n", | |
| " Returns list[float] of rewards, one per completion.\n", | |
| " All values are strictly in (0.001, 0.999) β enforced by the server.\n", | |
| " \"\"\"\n", | |
| " rewards = []\n", | |
| " for completion in completions:\n", | |
| " try:\n", | |
| " action = parse_action(completion)\n", | |
| " # Fresh env state for each reward evaluation\n", | |
| " requests.post(f'{ENV_URL}/reset', json={}, timeout=20)\n", | |
| " result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n", | |
| " reward = float(result.get('reward', 0.001))\n", | |
| " # Safety clamp β should already be clamped by server\n", | |
| " reward = max(0.001, min(0.999, reward))\n", | |
| " except Exception as e:\n", | |
| " print(f' [reward_fn error] {e}')\n", | |
| " reward = 0.001 # format/connection failure -> minimum\n", | |
| " rewards.append(reward)\n", | |
| " return rewards\n", | |
| "\n", | |
| "\n", | |
| "# Quick sanity check before training\n", | |
| "test_completions = [\n", | |
| " '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n", | |
| " 'some random text that is not JSON',\n", | |
| " '{\"agent\": \"balancer\", \"strategy\": \"undersample\"}',\n", | |
| "]\n", | |
| "test_rewards = reward_fn(test_completions)\n", | |
| "print('Reward function sanity check:')\n", | |
| "for c, r in zip(test_completions, test_rewards):\n", | |
| " print(f' action={c[:60]!r:65s} -> reward={r:.4f}')\n", | |
| "\n", | |
| "assert all(0.0 < r < 1.0 for r in test_rewards), 'Reward out of valid range!'\n", | |
| "print('All rewards in (0.0, 1.0): PASS')" | |
| ] | |
| }, | |
| # ββ Cell 6: Generation inspection helper βββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 6: Generation inspection β reward hacking detection\n", | |
| "#\n", | |
| "# Run this BEFORE and AFTER training to verify:\n", | |
| "# - Agent produces valid JSON actions\n", | |
| "# - Reward rises because accuracy rises, NOT because of exploits\n", | |
| "# - Agent varies its tool selection based on observation\n", | |
| "\n", | |
| "def inspect_generations(n_episodes=5, label='baseline'):\n", | |
| " print(f'\\n=== Generation inspection: {label} ===')\n", | |
| " tool_counts = {}\n", | |
| " valid_json_count = 0\n", | |
| " rewards = []\n", | |
| "\n", | |
| " for ep in range(n_episodes):\n", | |
| " obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n", | |
| " prompt = build_prompt(obs)\n", | |
| "\n", | |
| " FastLanguageModel.for_inference(model)\n", | |
| " inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n", | |
| " with torch.no_grad():\n", | |
| " out = model.generate(**inputs, max_new_tokens=80, temperature=0.7, do_sample=True)\n", | |
| " FastLanguageModel.for_training(model)\n", | |
| "\n", | |
| " response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n", | |
| " action = parse_action(response)\n", | |
| " agent = action.get('agent', '__invalid__')\n", | |
| "\n", | |
| " # Step in env\n", | |
| " result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n", | |
| " reward = result.get('reward', 0.0)\n", | |
| " acc_before = result.get('info', {}).get('prev_accuracy', '?')\n", | |
| " acc_after = result.get('info', {}).get('new_accuracy', '?')\n", | |
| "\n", | |
| " is_valid = isinstance(action.get('agent'), str) and action['agent'] in [\n", | |
| " 'cleaner', 'augmenter', 'balancer', 'relabeler', 'validator'\n", | |
| " ]\n", | |
| " if is_valid:\n", | |
| " valid_json_count += 1\n", | |
| " tool_counts[agent] = tool_counts.get(agent, 0) + 1\n", | |
| " rewards.append(reward)\n", | |
| "\n", | |
| " print(f' Ep{ep+1}: agent={agent:12s} reward={reward:.4f} '\n", | |
| " f'acc {acc_before}->{acc_after} '\n", | |
| " f'raw_output={response[:80]!r}')\n", | |
| "\n", | |
| " print(f'\\n Valid JSON rate: {valid_json_count}/{n_episodes}')\n", | |
| " print(f' Tool distribution: {tool_counts}')\n", | |
| " print(f' Mean reward: {sum(rewards)/len(rewards):.4f}')\n", | |
| " print()\n", | |
| "\n", | |
| " # Reward hacking checks\n", | |
| " if len(tool_counts) == 1:\n", | |
| " print(' β οΈ WARNING: Agent always picks same tool β possible reward hacking!')\n", | |
| " else:\n", | |
| " print(' β Tool diversity OK β no same-tool spamming detected.')\n", | |
| "\n", | |
| " return {'valid_json_rate': valid_json_count/n_episodes, 'mean_reward': sum(rewards)/len(rewards), 'tool_counts': tool_counts}\n", | |
| "\n", | |
| "\n", | |
| "# Run BEFORE training β establishes the baseline\n", | |
| "before_stats = inspect_generations(n_episodes=5, label='BEFORE training')" | |
| ] | |
| }, | |
| # ββ Cell 7: GRPO Training βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 7: GRPO Training\n", | |
| "#\n", | |
| "# GRPOTrainer:\n", | |
| "# - samples num_generations completions per prompt\n", | |
| "# - calls reward_fn on all completions\n", | |
| "# - updates model to increase probability of higher-reward completions\n", | |
| "#\n", | |
| "# The reward_fn IS the environment verifier β no learned reward model needed.\n", | |
| "\n", | |
| "config = GRPOConfig(\n", | |
| " output_dir='./datacentric-grpo',\n", | |
| " num_train_epochs=3,\n", | |
| " per_device_train_batch_size=1,\n", | |
| " gradient_accumulation_steps=4,\n", | |
| " num_generations=4, # 4 completions per prompt, ranked by reward\n", | |
| " max_completion_length=100, # action JSON is short\n", | |
| " learning_rate=5e-6,\n", | |
| " logging_steps=5,\n", | |
| " save_steps=50,\n", | |
| " warmup_ratio=0.1,\n", | |
| " lr_scheduler_type='cosine',\n", | |
| " report_to='none', # swap to 'wandb' for live curves\n", | |
| " remove_unused_columns=False,\n", | |
| ")\n", | |
| "\n", | |
| "trainer = GRPOTrainer(\n", | |
| " model=model,\n", | |
| " args=config,\n", | |
| " reward_funcs=reward_fn, # environment verifier IS the reward function\n", | |
| " train_dataset=dataset,\n", | |
| " tokenizer=tokenizer,\n", | |
| ")\n", | |
| "\n", | |
| "print('Starting GRPO training...')\n", | |
| "print(f' Dataset: {len(dataset)} prompts')\n", | |
| "print(f' Generations per prompt: {config.num_generations}')\n", | |
| "print(f' Total reward evaluations per epoch: {len(dataset) * config.num_generations}')\n", | |
| "print()\n", | |
| "\n", | |
| "trainer.train()\n", | |
| "\n", | |
| "print('Training complete.')" | |
| ] | |
| }, | |
| # ββ Cell 8: Post-training inspection βββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c8", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 8: Post-training generation inspection\n", | |
| "# Compare BEFORE vs AFTER to verify no reward hacking occurred.\n", | |
| "# Expected: higher valid JSON rate, higher mean reward, varied tool selection.\n", | |
| "\n", | |
| "after_stats = inspect_generations(n_episodes=5, label='AFTER training')\n", | |
| "\n", | |
| "print('=== Before vs After comparison ===')\n", | |
| "print(f' Valid JSON rate: {before_stats[\"valid_json_rate\"]:.1%} -> {after_stats[\"valid_json_rate\"]:.1%}')\n", | |
| "print(f' Mean reward: {before_stats[\"mean_reward\"]:.4f} -> {after_stats[\"mean_reward\"]:.4f}')\n", | |
| "print(f' Tool diversity: {len(before_stats[\"tool_counts\"])} tools -> {len(after_stats[\"tool_counts\"])} tools')\n", | |
| "\n", | |
| "if after_stats['mean_reward'] > before_stats['mean_reward']:\n", | |
| " print(' β Reward improved after training β learning signal working.')\n", | |
| "else:\n", | |
| " print(' β οΈ Reward did not improve β check env connectivity and reward function.')" | |
| ] | |
| }, | |
| # ββ Cell 9: Save model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| { | |
| "cell_type": "code", "id": "c9", "metadata": {}, "outputs": [], "execution_count": None, | |
| "source": [ | |
| "# Cell 9: Save model using Unsloth merge path\n", | |
| "#\n", | |
| "# CRITICAL: Do NOT use naive save_pretrained on a 4-bit model.\n", | |
| "# That upcasts to 16-bit then merges naively, which damages model quality.\n", | |
| "# Use save_pretrained_merged with save_method='merged_16bit' instead.\n", | |
| "\n", | |
| "SAVE_DIR = 'datacentric-grpo-final'\n", | |
| "\n", | |
| "model.save_pretrained_merged(\n", | |
| " SAVE_DIR,\n", | |
| " tokenizer,\n", | |
| " save_method='merged_16bit', # correct Unsloth merge path\n", | |
| ")\n", | |
| "\n", | |
| "print(f'Model saved to {SAVE_DIR}/')\n", | |
| "print('Test inference immediately to verify model quality before uploading.')\n", | |
| "\n", | |
| "# Quick inference test\n", | |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
| "test_tok = AutoTokenizer.from_pretrained(SAVE_DIR)\n", | |
| "test_model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype=torch.float16, device_map='auto')\n", | |
| "\n", | |
| "obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n", | |
| "inp = test_tok(build_prompt(obs), return_tensors='pt').to('cuda')\n", | |
| "with torch.no_grad():\n", | |
| " out = test_model.generate(**inp, max_new_tokens=80)\n", | |
| "response = test_tok.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n", | |
| "print(f'Post-save inference test: {response}')\n", | |
| "action = parse_action(response)\n", | |
| "print(f'Parsed action: {action}')\n", | |
| "print('Model save and inference verified.')" | |
| ] | |
| }, | |
| ] | |
| } | |
| with open("training/train.ipynb", "w") as f: | |
| json.dump(nb, f, indent=1) | |
| print("Notebook written successfully.") | |