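"""make_notebook.py: generate training/train.ipynb, the DataCentric-Env GRPO training notebook."""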
import json
nb = {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python", "version": "3.10.0"},
"accelerator": "GPU"
},
"cells": [
# ── Cell 0: markdown intro ────────────────────────────────────────────────────
{
"cell_type": "markdown", "id": "md0", "metadata": {},
"source": [
"# DataCentric-Env β€” GRPO Training Notebook\n\n",
"**End-to-end runnable by a judge.** Trains Qwen2.5-3B-Instruct as a data quality agent using GRPO via TRL + Unsloth.\n\n",
"**Runtime required:** GPU (T4 or better). In Colab: `Runtime β†’ Change runtime type β†’ T4 GPU`\n\n",
"## What this notebook does\n",
"1. Installs Unsloth, TRL, and dependencies\n",
"2. Loads Qwen2.5-3B-Instruct in 4-bit with LoRA adapters\n",
"3. Builds a prompt dataset from live environment observations\n",
"4. Defines a **reward function** that calls the live environment `/step` endpoint\n",
"5. Trains with `GRPOTrainer` β€” the reward function IS the environment verifier\n",
"6. Inspects sampled generations during training to catch reward hacking\n",
"7. Saves the model using the **Unsloth merge path** (not naive `save_pretrained`)\n\n",
"> **Grader note:** Format compliance is Grader 1, completely independent of accuracy.\n",
"> All reward values are clamped to strictly `(0.001, 0.999)` β€” never 0.0 or 1.0."
]
},
# ── Cell 1: Install ───────────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c1", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 1: Install dependencies\n",
"# Unsloth install β€” use the correct pip source for Colab\n",
"import subprocess, sys\n",
"subprocess.run([sys.executable, '-m', 'pip', 'install',\n",
" 'unsloth', 'trl>=0.12.0', 'transformers>=4.45.0',\n",
" 'accelerate', 'peft', 'datasets', 'requests', 'bitsandbytes'\n",
"], check=True)\n",
"print('All dependencies installed.')"
]
},
# ── Cell 2: Imports + config ─────────────────────────────────────────────────
{
"cell_type": "code", "id": "c2", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 2: Imports and configuration\n",
"import json, re, requests, torch\n",
"from unsloth import FastLanguageModel\n",
"from trl import GRPOTrainer, GRPOConfig\n",
"from datasets import Dataset\n",
"\n",
"# Live environment URL β€” deployed on HuggingFace Spaces\n",
"ENV_URL = 'https://aswini-kumar-datacentric-env.hf.space'\n",
"\n",
"# Verify the environment is reachable before training\n",
"health = requests.get(f'{ENV_URL}/health', timeout=30)\n",
"assert health.status_code == 200, f'Environment not reachable: {health.status_code}'\n",
"print(f'Environment health: {health.json()}')\n",
"\n",
"SYSTEM_PROMPT = (\n",
" 'You are a data quality agent. You receive dataset statistics and must choose '\n",
" 'which specialist tool to call to improve the dataset so a downstream classifier '\n",
" 'performs better.\\n\\n'\n",
" 'Always respond with valid JSON in this exact format:\\n'\n",
" '{\"agent\": \"<tool_name>\", \"target\": \"<column_or_all>\", \"strategy\": \"<strategy_name>\"}\\n\\n'\n",
" 'Available tools: cleaner, augmenter, balancer, relabeler, validator\\n'\n",
" 'Cleaner strategies: median_impute, mean_impute, drop_rows\\n'\n",
" 'Balancer strategies: undersample\\n'\n",
" 'Relabeler: use when labels are noisy, costs 2 budget points.'\n",
")\n",
"\n",
"def build_prompt(obs: dict) -> str:\n",
" return SYSTEM_PROMPT + '\\n\\nCurrent state:\\n' + json.dumps(obs, indent=2) + '\\n\\nYour action:'\n",
"\n",
"print('Config ready. ENV_URL:', ENV_URL)"
]
},
# ── Cell 3: Model setup ───────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c3", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 3: Load Qwen2.5-3B-Instruct in 4-bit with LoRA\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name='unsloth/Qwen2.5-3B-Instruct',\n",
" max_seq_length=1024,\n",
" load_in_4bit=True,\n",
" dtype=None, # auto\n",
")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r=16,\n",
" lora_alpha=32,\n",
" lora_dropout=0.0,\n",
" target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
" use_gradient_checkpointing='unsloth',\n",
" random_state=42,\n",
")\n",
"\n",
"print(f'Model loaded. Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
]
},
# ── Cell 4: Build prompt dataset ─────────────────────────────────────────────
{
"cell_type": "code", "id": "c4", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 4: Build prompt dataset\n",
"# Collect N initial observations by calling /reset.\n",
"# GRPOTrainer will generate multiple completions per prompt and score each via reward_fn.\n",
"\n",
"N_PROMPTS = 60 # number of distinct starting states\n",
"prompts = []\n",
"\n",
"difficulties = ['easy'] * 30 + ['medium'] * 20 + ['hard'] * 10 # curriculum mix\n",
"\n",
"for i, diff in enumerate(difficulties):\n",
" obs = requests.post(f'{ENV_URL}/reset', json={'difficulty': diff}, timeout=30).json()\n",
" prompts.append({\n",
" 'prompt': build_prompt(obs),\n",
" 'difficulty': diff,\n",
" 'initial_obs': json.dumps(obs),\n",
" })\n",
" if i % 10 == 0:\n",
" print(f' Collected {i+1}/{N_PROMPTS} prompts (difficulty={diff})')\n",
"\n",
"dataset = Dataset.from_list(prompts)\n",
"print(f'Dataset ready: {len(dataset)} prompts')\n",
"print('Sample prompt (truncated):', dataset[0]['prompt'][:300])"
]
},
# ── Cell 5: Reward function ───────────────────────────────────────────────────
{
"cell_type": "code", "id": "c5", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 5: Reward function β€” this IS the environment verifier\n",
"#\n",
"# TRL GRPOTrainer calls reward_fn(completions, prompts, **kwargs) -> list[float]\n",
"# For each completion:\n",
"# 1. Parse the JSON action from the model output\n",
"# 2. Call /reset to get a fresh env state\n",
"# 3. Call /step with the parsed action\n",
"# 4. Return the environment reward (already in (0.001, 0.999))\n",
"#\n",
"# Grader 1 (format compliance) fires FIRST, independently of all other graders.\n",
"# Final reward is always clamped by the server before returning.\n",
"\n",
"def parse_action(text: str) -> dict:\n",
" \"\"\"Extract JSON action from model output. Returns fallback on parse failure.\"\"\"\n",
" text = text.strip()\n",
" # Try to extract JSON block\n",
" match = re.search(r'\\{[^}]+\\}', text, re.DOTALL)\n",
" if match:\n",
" try:\n",
" return json.loads(match.group(0))\n",
" except Exception:\n",
" pass\n",
" # Fallback β€” intentionally malformed so Grader 1 fires 0.001\n",
" return {'agent': '__invalid__'}\n",
"\n",
"\n",
"def reward_fn(completions, prompts=None, **kwargs):\n",
" \"\"\"\n",
" Called by GRPOTrainer once per batch.\n",
" Returns list[float] of rewards, one per completion.\n",
" All values are strictly in (0.001, 0.999) β€” enforced by the server.\n",
" \"\"\"\n",
" rewards = []\n",
" for completion in completions:\n",
" try:\n",
" action = parse_action(completion)\n",
" # Fresh env state for each reward evaluation\n",
" requests.post(f'{ENV_URL}/reset', json={}, timeout=20)\n",
" result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
" reward = float(result.get('reward', 0.001))\n",
" # Safety clamp β€” should already be clamped by server\n",
" reward = max(0.001, min(0.999, reward))\n",
" except Exception as e:\n",
" print(f' [reward_fn error] {e}')\n",
" reward = 0.001 # format/connection failure -> minimum\n",
" rewards.append(reward)\n",
" return rewards\n",
"\n",
"\n",
"# Quick sanity check before training\n",
"test_completions = [\n",
" '{\"agent\": \"cleaner\", \"target\": \"all\", \"strategy\": \"median_impute\"}',\n",
" 'some random text that is not JSON',\n",
" '{\"agent\": \"balancer\", \"strategy\": \"undersample\"}',\n",
"]\n",
"test_rewards = reward_fn(test_completions)\n",
"print('Reward function sanity check:')\n",
"for c, r in zip(test_completions, test_rewards):\n",
" print(f' action={c[:60]!r:65s} -> reward={r:.4f}')\n",
"\n",
"assert all(0.0 < r < 1.0 for r in test_rewards), 'Reward out of valid range!'\n",
"print('All rewards in (0.0, 1.0): PASS')"
]
},
# ── Cell 6: Generation inspection helper ─────────────────────────────────────
{
"cell_type": "code", "id": "c6", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 6: Generation inspection β€” reward hacking detection\n",
"#\n",
"# Run this BEFORE and AFTER training to verify:\n",
"# - Agent produces valid JSON actions\n",
"# - Reward rises because accuracy rises, NOT because of exploits\n",
"# - Agent varies its tool selection based on observation\n",
"\n",
"def inspect_generations(n_episodes=5, label='baseline'):\n",
" print(f'\\n=== Generation inspection: {label} ===')\n",
" tool_counts = {}\n",
" valid_json_count = 0\n",
" rewards = []\n",
"\n",
" for ep in range(n_episodes):\n",
" obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
" prompt = build_prompt(obs)\n",
"\n",
" FastLanguageModel.for_inference(model)\n",
" inputs = tokenizer(prompt, return_tensors='pt').to('cuda')\n",
" with torch.no_grad():\n",
" out = model.generate(**inputs, max_new_tokens=80, temperature=0.7, do_sample=True)\n",
" FastLanguageModel.for_training(model)\n",
"\n",
" response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
" action = parse_action(response)\n",
" agent = action.get('agent', '__invalid__')\n",
"\n",
" # Step in env\n",
" result = requests.post(f'{ENV_URL}/step', json=action, timeout=20).json()\n",
" reward = result.get('reward', 0.0)\n",
" acc_before = result.get('info', {}).get('prev_accuracy', '?')\n",
" acc_after = result.get('info', {}).get('new_accuracy', '?')\n",
"\n",
" is_valid = isinstance(action.get('agent'), str) and action['agent'] in [\n",
" 'cleaner', 'augmenter', 'balancer', 'relabeler', 'validator'\n",
" ]\n",
" if is_valid:\n",
" valid_json_count += 1\n",
" tool_counts[agent] = tool_counts.get(agent, 0) + 1\n",
" rewards.append(reward)\n",
"\n",
" print(f' Ep{ep+1}: agent={agent:12s} reward={reward:.4f} '\n",
" f'acc {acc_before}->{acc_after} '\n",
" f'raw_output={response[:80]!r}')\n",
"\n",
" print(f'\\n Valid JSON rate: {valid_json_count}/{n_episodes}')\n",
" print(f' Tool distribution: {tool_counts}')\n",
" print(f' Mean reward: {sum(rewards)/len(rewards):.4f}')\n",
" print()\n",
"\n",
" # Reward hacking checks\n",
" if len(tool_counts) == 1:\n",
" print(' ⚠️ WARNING: Agent always picks same tool β€” possible reward hacking!')\n",
" else:\n",
" print(' βœ… Tool diversity OK β€” no same-tool spamming detected.')\n",
"\n",
" return {'valid_json_rate': valid_json_count/n_episodes, 'mean_reward': sum(rewards)/len(rewards), 'tool_counts': tool_counts}\n",
"\n",
"\n",
"# Run BEFORE training β€” establishes the baseline\n",
"before_stats = inspect_generations(n_episodes=5, label='BEFORE training')"
]
},
# ── Cell 7: GRPO Training ─────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c7", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 7: GRPO Training\n",
"#\n",
"# GRPOTrainer:\n",
"# - samples num_generations completions per prompt\n",
"# - calls reward_fn on all completions\n",
"# - updates model to increase probability of higher-reward completions\n",
"#\n",
"# The reward_fn IS the environment verifier β€” no learned reward model needed.\n",
"\n",
"config = GRPOConfig(\n",
" output_dir='./datacentric-grpo',\n",
" num_train_epochs=3,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" num_generations=4, # 4 completions per prompt, ranked by reward\n",
" max_completion_length=100, # action JSON is short\n",
" learning_rate=5e-6,\n",
" logging_steps=5,\n",
" save_steps=50,\n",
" warmup_ratio=0.1,\n",
" lr_scheduler_type='cosine',\n",
" report_to='none', # swap to 'wandb' for live curves\n",
" remove_unused_columns=False,\n",
")\n",
"\n",
"trainer = GRPOTrainer(\n",
" model=model,\n",
" args=config,\n",
" reward_funcs=reward_fn, # environment verifier IS the reward function\n",
" train_dataset=dataset,\n",
" tokenizer=tokenizer,\n",
")\n",
"\n",
"print('Starting GRPO training...')\n",
"print(f' Dataset: {len(dataset)} prompts')\n",
"print(f' Generations per prompt: {config.num_generations}')\n",
"print(f' Total reward evaluations per epoch: {len(dataset) * config.num_generations}')\n",
"print()\n",
"\n",
"trainer.train()\n",
"\n",
"print('Training complete.')"
]
},
# ── Cell 8: Post-training inspection ─────────────────────────────────────────
{
"cell_type": "code", "id": "c8", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 8: Post-training generation inspection\n",
"# Compare BEFORE vs AFTER to verify no reward hacking occurred.\n",
"# Expected: higher valid JSON rate, higher mean reward, varied tool selection.\n",
"\n",
"after_stats = inspect_generations(n_episodes=5, label='AFTER training')\n",
"\n",
"print('=== Before vs After comparison ===')\n",
"print(f' Valid JSON rate: {before_stats[\"valid_json_rate\"]:.1%} -> {after_stats[\"valid_json_rate\"]:.1%}')\n",
"print(f' Mean reward: {before_stats[\"mean_reward\"]:.4f} -> {after_stats[\"mean_reward\"]:.4f}')\n",
"print(f' Tool diversity: {len(before_stats[\"tool_counts\"])} tools -> {len(after_stats[\"tool_counts\"])} tools')\n",
"\n",
"if after_stats['mean_reward'] > before_stats['mean_reward']:\n",
" print(' βœ… Reward improved after training β€” learning signal working.')\n",
"else:\n",
" print(' ⚠️ Reward did not improve β€” check env connectivity and reward function.')"
]
},
# ── Cell 9: Save model ────────────────────────────────────────────────────────
{
"cell_type": "code", "id": "c9", "metadata": {}, "outputs": [], "execution_count": None,
"source": [
"# Cell 9: Save model using Unsloth merge path\n",
"#\n",
"# CRITICAL: Do NOT use naive save_pretrained on a 4-bit model.\n",
"# That upcasts to 16-bit then merges naively, which damages model quality.\n",
"# Use save_pretrained_merged with save_method='merged_16bit' instead.\n",
"\n",
"SAVE_DIR = 'datacentric-grpo-final'\n",
"\n",
"model.save_pretrained_merged(\n",
" SAVE_DIR,\n",
" tokenizer,\n",
" save_method='merged_16bit', # correct Unsloth merge path\n",
")\n",
"\n",
"print(f'Model saved to {SAVE_DIR}/')\n",
"print('Test inference immediately to verify model quality before uploading.')\n",
"\n",
"# Quick inference test\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"test_tok = AutoTokenizer.from_pretrained(SAVE_DIR)\n",
"test_model = AutoModelForCausalLM.from_pretrained(SAVE_DIR, torch_dtype=torch.float16, device_map='auto')\n",
"\n",
"obs = requests.post(f'{ENV_URL}/reset', json={}, timeout=30).json()\n",
"inp = test_tok(build_prompt(obs), return_tensors='pt').to('cuda')\n",
"with torch.no_grad():\n",
" out = test_model.generate(**inp, max_new_tokens=80)\n",
"response = test_tok.decode(out[0][inp['input_ids'].shape[1]:], skip_special_tokens=True)\n",
"print(f'Post-save inference test: {response}')\n",
"action = parse_action(response)\n",
"print(f'Parsed action: {action}')\n",
"print('Model save and inference verified.')"
]
},
]
}
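
# Pre-flight check before writing (a sketch; assumes nbformat is installed,
# e.g. `pip install nbformat`): validate the dict against the notebook schema
# so a malformed cell fails here rather than when Jupyter opens the file.
try:
    import nbformat
    nbformat.validate(nbformat.from_dict(nb))
    print("nbformat schema validation passed.")
except ImportError:
    print("nbformat not installed; skipping schema validation.")

import os
os.makedirs("training", exist_ok=True)  # ensure the output directory exists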
with open("training/train.ipynb", "w") as f:
json.dump(nb, f, indent=1)
print("Notebook written successfully.")