{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python", "version": "3.10" }, "accelerator": "GPU", "colab": { "provenance": [], "gpuType": "T4" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "ZXq1I39Jxz1q" }, "source": [ "# ExecAssist — GRPO Training on Colab T4\n", "\n", "Trains Qwen2.5-0.5B-Instruct with TRL GRPO on the deployed ExecAssist environment.\n", "\n", "**Runtime:** Set runtime → T4 GPU. Total run ~45 min.\n", "\n", "**Outputs:** `training_results.png` + `results.json` for your README." ], "id": "ZXq1I39Jxz1q" }, { "cell_type": "markdown", "metadata": { "id": "p1d4Lkp_xz1r" }, "source": [ "## 1. Install + GPU check" ], "id": "p1d4Lkp_xz1r" }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fXDfPHQvxz1r", "outputId": "afb1898b-7db8-4783-8b5d-619d9bbd600e" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m52.8/52.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m530.7/530.7 MB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m366.1/366.1 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m169.9/169.9 MB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.5/196.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.4/60.4 MB\u001b[0m \u001b[31m10.9 
MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.3/188.3 MB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m423.1/423.1 MB\u001b[0m \u001b[31m824.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.7/10.7 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.2/90.2 MB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m81.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m214.1/214.1 MB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m47.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.5/59.5 MB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.9/200.9 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m145.9/145.9 MB\u001b[0m \u001b[31m131.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m" ] } ], "source": [ "!pip uninstall -y torchvision torchaudio -q\n", "!pip install -q -U torch transformers trl datasets accelerate matplotlib" ], "id": "fXDfPHQvxz1r" }, { "cell_type": "code", "metadata": { "id": "QOqVp9nrxz1r" }, "execution_count": null, 
"outputs": [], "source": [ "import torch\n", "assert torch.cuda.is_available(), 'No GPU detected — set Runtime > Change runtime type > T4 GPU'\n", "print('GPU:', torch.cuda.get_device_name(0))\n", "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')" ], "id": "QOqVp9nrxz1r" }, { "cell_type": "markdown", "metadata": { "id": "NeEF7A1mxz1s" }, "source": [ "## 2. Config" ], "id": "NeEF7A1mxz1s" }, { "cell_type": "code", "metadata": { "id": "YcYdPAMNxz1s" }, "execution_count": null, "outputs": [], "source": [ "import os, json, requests, copy\n", "from datetime import datetime\n", "\n", "ENV_URL = 'https://devanshudon-exec-assist.hf.space'\n", "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n", "N_PER_TASK = 30 # 30 each = 90 scenarios total\n", "N_EVAL = 10 # baseline/trained eval samples per task\n", "OUT_DIR = 'training_logs'\n", "os.makedirs(OUT_DIR, exist_ok=True)" ], "id": "YcYdPAMNxz1s" }, { "cell_type": "markdown", "metadata": { "id": "2ODWK44mxz1s" }, "source": [ "## 3. Collect scenarios from your HF Space\n", "\n", "Each `/reset` returns a fresh scenario. We pull 90 of them once and reuse them for both training and evaluation — this keeps reward scoring deterministic." ], "id": "2ODWK44mxz1s" }, { "cell_type": "code", "metadata": { "id": "_BpUrmBOxz1s" }, "execution_count": null, "outputs": [], "source": [ "def reset_env(task):\n", " r = requests.post(f'{ENV_URL}/reset', params={'task': task}, timeout=60)\n", " r.raise_for_status()\n", " return r.json()\n", "\n", "scenarios = []\n", "for task in ['easy', 'medium', 'hard']:\n", " for i in range(N_PER_TASK):\n", " try:\n", " data = reset_env(task)\n", " scenarios.append({'task': task, 'observation': data.get('observation', {})})\n", " except Exception as e:\n", " print(f' ! 
def _as_text(value):
    """Return ``value`` unchanged when it is a ``str``; otherwise return ``''``.

    Model completions are free-form JSON, so a field such as ``email_reply``
    can come back as a number, list or object. Mapping those to the empty
    string lets the scorers apply their normal "missing email" handling
    instead of raising ``AttributeError`` on ``.strip()`` / ``.lower()``.
    """
    return value if isinstance(value, str) else ''


def _existing_meetings(obs):
    """Return the list under ``obs['calendar']['existing_meetings']``.

    Any other shape (``obs`` or the calendar not a dict, key missing, value
    not a list/tuple) yields an empty list, i.e. "nothing to conflict with".
    """
    calendar = obs.get('calendar') if isinstance(obs, dict) else None
    meetings = calendar.get('existing_meetings') if isinstance(calendar, dict) else None
    return meetings if isinstance(meetings, (list, tuple)) else []


def _overlaps_existing(start, end, obs):
    """Return True if ``[start, end)`` overlaps any parseable existing meeting.

    Existing entries that cannot be read (not a dict, missing key, bad ISO
    string, naive/aware datetime mismatch) are skipped rather than counted
    as conflicts.
    """
    for m in _existing_meetings(obs):
        try:
            es = datetime.fromisoformat(m['start_time'])
            ee = datetime.fromisoformat(m['end_time'])
            # Back-to-back meetings (end == other start) do not overlap.
            if start < ee and es < end:
                return True
        except (KeyError, TypeError, ValueError):
            continue
    return False


def score_email(email):
    """Heuristically score the tone/shape of an email reply in [0.0, 1.0].

    Args:
        email: The reply text. Anything that is not a string, or that has
            fewer than 20 non-padding characters, scores 0.0.

    Returns:
        0.5 base, plus bonuses for polite wording (+0.15), a greeting at the
        start (+0.10), a sign-off (+0.10) and a 25-200 word length (+0.15);
        minus 0.10 above 250 words or 0.20 below 15 words. Clamped to [0, 1].
    """
    email = _as_text(email)
    if len(email.strip()) < 20:
        return 0.0
    e, score = email.lower(), 0.5
    if any(w in e for w in ['thank', 'appreciate', 'please', 'regards', 'sincerely', 'best']):
        score += 0.15
    if any(e.lstrip().startswith(g) for g in ['hi ', 'hello ', 'dear ', 'good morning', 'good afternoon']):
        score += 0.10
    if any(c in e for c in ['regards', 'sincerely', 'best,', 'thanks,', 'cheers']):
        score += 0.10
    wc = len(email.split())
    if 25 <= wc <= 200:
        score += 0.15
    elif wc > 250:
        score -= 0.10
    elif wc < 15:
        score -= 0.20
    return max(0.0, min(1.0, score))


def score_scheduling(meeting, obs):
    """Score a proposed meeting as the fraction of seven checks that pass.

    Checks: ``start_time`` key present, ``end_time`` key present, at least two
    participants (as a list), non-empty subject, no overlap with existing
    meetings, start hour within 9-17, duration between 15 and 120 minutes.

    Args:
        meeting: The ``meeting_details`` dict from the action; any non-dict
            scores 0.0.
        obs: The scenario observation; existing meetings are read from
            ``obs['calendar']['existing_meetings']`` when available.

    Returns:
        A float in [0.0, 1.0]. Unparseable timestamps simply fail the three
        time-based checks; they never raise.
    """
    if not isinstance(meeting, dict):
        return 0.0
    has_s = 'start_time' in meeting
    has_e = 'end_time' in meeting
    participants = meeting.get('participants')
    has_p = isinstance(participants, list) and len(participants) >= 2
    has_subj = bool(meeting.get('subject'))
    no_conflict = in_hours = good_dur = False

    if has_s and has_e:
        try:
            ns = datetime.fromisoformat(meeting['start_time'])
            ne = datetime.fromisoformat(meeting['end_time'])
        except (TypeError, ValueError):
            # Non-string or malformed ISO timestamp: time checks stay False.
            ns = ne = None
        if ns is not None:
            # NOTE(review): only the start hour is checked; a meeting that
            # starts before 17:00 but ends later still passes — confirm this
            # matches the environment's own scorer.
            in_hours = 9 <= ns.hour < 17
            try:
                minutes = (ne - ns).total_seconds() / 60
            except TypeError:
                # Mixing naive and timezone-aware datetimes cannot be subtracted.
                minutes = None
            if minutes is not None:
                good_dur = 15 <= minutes <= 120
                no_conflict = not _overlaps_existing(ns, ne, obs)

    checks = [has_s, has_e, has_p, has_subj, no_conflict, in_hours, good_dur]
    return sum(checks) / len(checks)


def score_conflict(action, obs):
    """Score how well the action handles a scheduling conflict, in [0.0, 1.0].

    Starts at 0.5 and adds 0.15 if the email mentions conflict-style wording,
    0.25 if ``meeting_details['proposed_alternatives']`` is a list of at least
    two entries, and 0.10 if ``calendar_action`` is ``'propose_alternatives'``.

    Args:
        action: Parsed action dict.
        obs: Scenario observation (accepted for signature parity; not read).
    """
    cal_action = action.get('calendar_action', '')
    meeting = action.get('meeting_details') or {}
    email = _as_text(action.get('email_reply')).lower()
    score = 0.5
    if any(w in email for w in ['conflict', 'unavailable', 'alternative', 'instead', 'however', 'unfortunately']):
        score += 0.15
    alts = meeting.get('proposed_alternatives') if isinstance(meeting, dict) else None
    if alts and isinstance(alts, list) and len(alts) >= 2:
        score += 0.25
    if cal_action == 'propose_alternatives':
        score += 0.10
    return max(0.0, min(1.0, score))


def penalties(action):
    """Return the (non-positive) penalty for missing or oversized fields.

    -0.30 when the email is missing/too short (< 20 stripped characters or not
    a string), -0.15 when it exceeds 1500 characters, -0.40 when
    ``meeting_details`` is absent or empty.
    """
    p = 0.0
    email = _as_text(action.get('email_reply'))
    if len(email.strip()) < 20:
        p -= 0.30
    if len(email) > 1500:
        p -= 0.15
    if not action.get('meeting_details'):
        p -= 0.40
    return p


def compose_reward(action, obs, task):
    """Combine the component scores into one reward in [0.0, 1.0].

    Weights per task: easy = 50% email / 50% scheduling; medium = 30% email /
    40% conflict / 30% scheduling; any other task (hard) = 34% email / 33%
    scheduling / 33% conflict. Penalties are added before clamping.

    Args:
        action: Parsed model action; a non-dict yields 0.0.
        obs: Scenario observation passed through to the scorers.
        task: ``'easy'``, ``'medium'`` or anything else (treated as hard).
    """
    if not isinstance(action, dict):
        return 0.0
    e = score_email(action.get('email_reply', ''))
    s = score_scheduling(action.get('meeting_details'), obs)
    c = score_conflict(action, obs)
    p = penalties(action)
    if task == 'easy':
        base = 0.50 * e + 0.50 * s
    elif task == 'medium':
        base = 0.30 * e + 0.40 * c + 0.30 * s
    else:
        base = 0.34 * e + 0.33 * s + 0.33 * c
    return max(0.0, min(1.0, base + p))
Best regards, Alex Chen',\n", " 'calendar_action':'book',\n", " 'meeting_details':{'participants':['s@x.com','a@x.com'],\n", " 'start_time':'2026-04-28T14:00:00',\n", " 'end_time':'2026-04-28T14:30:00',\n", " 'subject':'Test'}}\n", "print('Demo reward (easy):', compose_reward(demo, scenarios[0]['observation'], 'easy'))" ], "id": "nprQokwPxz1s" }, { "cell_type": "markdown", "metadata": { "id": "4nS43nVkxz1s" }, "source": [ "## 5. Build prompts + HF Dataset" ], "id": "4nS43nVkxz1s" }, { "cell_type": "code", "metadata": { "id": "rfhuu5Ebxz1s" }, "execution_count": null, "outputs": [], "source": [ "def build_prompt(obs):\n", " emails = obs.get('emails', [])\n", " cal = obs.get('calendar', {})\n", " em_str = ''\n", " for e in emails:\n", " em_str += f\"From: {e.get('sender','?')}\\nSubject: {e.get('subject','?')}\\nPriority: {e.get('priority','normal')}\\n{e.get('body','')}\\n---\\n\"\n", " meetings = cal.get('existing_meetings', [])\n", " cal_str = 'Existing meetings:\\n' + ('\\n'.join(f\" - {m.get('subject','?')}: {m.get('start_time','?')} -> {m.get('end_time','?')}\" for m in meetings) if meetings else ' (none)')\n", " exec_name = cal.get('executive_name','Alex Chen')\n", " return f\"\"\"You are {exec_name}'s executive assistant. Reply to incoming email and book the meeting.\n", "\n", "EMAILS:\n", "{em_str}\n", "\n", "{cal_str}\n", "\n", "Working hours: 9-17. Duration: 15-120 min. 
def parse_action(text):
    """Extract the first JSON object embedded in a model completion.

    Decoding starts at the first ``{`` in ``text`` and uses the JSON parser
    itself to find where the object ends, so braces that appear inside string
    values (e.g. an email body containing ``}``) do not truncate the object.
    Anything after the object is ignored.

    Args:
        text: Raw completion text.

    Returns:
        The decoded ``dict``, or ``{}`` when ``text`` is not a string, has no
        ``{``, or the content at the first ``{`` is not a valid JSON object.
    """
    if not isinstance(text, str):
        return {}
    start = text.find('{')
    if start == -1:
        return {}
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; pathological nesting can
        # raise RecursionError. Either way there is no usable action.
        return {}
    return obj if isinstance(obj, dict) else {}


def reward_function(completions, scenario, task, **kwargs):
    """GRPO reward function: score each completion against its scenario.

    Args:
        completions: One entry per sample; either a plain string or a list of
            chat messages, in which case the last message's ``'content'`` is
            scored.
        scenario: Per-sample observation dicts (dataset column).
        task: Per-sample task names (dataset column).
        **kwargs: Any other dataset columns the trainer forwards; unused.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    for comp, scen, t in zip(completions, scenario, task):
        text = comp[-1]['content'] if isinstance(comp, list) else comp
        action = parse_action(text)
        try:
            r = compose_reward(action, scen, t)
        except Exception:
            # Best-effort: a completion the scorer cannot handle earns nothing
            # rather than aborting the training step.
            r = 0.0
        rewards.append(float(r))
    return rewards
def evaluate(m, tok, dataset, n_per_task=10):
    """Sample one completion per example and score it, up to ``n_per_task`` per task.

    Args:
        m: Causal LM to evaluate; it is switched to eval mode.
        tok: Tokenizer matching ``m`` (must provide a chat template).
        dataset: Iterable of rows with ``'task'``, ``'prompt'`` and
            ``'scenario'`` entries.
        n_per_task: Maximum number of scored examples per task.

    Returns:
        Dict mapping ``'easy'`` / ``'medium'`` / ``'hard'`` to the list of
        rewards collected for that task. Completions that yield no parseable
        action, or that the scorer cannot handle, count as 0.0.
    """
    m.eval()
    # Follow the model instead of assuming CUDA: a freshly reloaded model may
    # still be on CPU, and inputs must live on the same device as the weights.
    device = next(m.parameters()).device
    results = {'easy': [], 'medium': [], 'hard': []}
    for ex in dataset:
        t = ex['task']
        if len(results[t]) >= n_per_task:
            continue
        msgs = [{'role': 'user', 'content': ex['prompt']}]
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        ids = tok(text, return_tensors='pt', truncation=True, max_length=1024).to(device)
        with torch.no_grad():
            out = m.generate(**ids, max_new_tokens=350, do_sample=True, temperature=0.7,
                             top_p=0.9, pad_token_id=tok.eos_token_id)
        # Drop the prompt tokens so only the newly generated text is decoded.
        comp = tok.decode(out[0][ids['input_ids'].shape[1]:], skip_special_tokens=True)
        action = parse_action(comp)
        try:
            r = compose_reward(action, ex['scenario'], t) if action else 0.0
        except Exception:
            # Same policy as reward_function: an unscorable completion earns
            # 0.0 instead of aborting the whole evaluation run.
            r = 0.0
        results[t].append(r)
    return results
baseline...')\n", "baseline = evaluate(model, tokenizer, ds, n_per_task=N_EVAL)\n", "for t, rs in baseline.items():\n", " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')" ], "id": "8aM2JB7pxz1t" }, { "cell_type": "markdown", "metadata": { "id": "QbsM0OX1xz1t" }, "source": [ "## 9. Train with TRL GRPO\n", "\n", "Hyperparameters chosen for T4 + 0.5B model. Should complete in ~25-35 min." ], "id": "QbsM0OX1xz1t" }, { "cell_type": "code", "source": [ "# Cast model to float32 for stable GRPO training\n", "model = model.float()\n", "print(\"Model dtype:\", next(model.parameters()).dtype) # should print: torch.float32" ], "metadata": { "id": "aWkbPpIf3SBN" }, "id": "aWkbPpIf3SBN", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Reload clean model (undo bad training)\n", "from transformers import AutoModelForCausalLM\n", "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).float()\n", "print(\"✅ Clean model reloaded\")" ], "metadata": { "id": "uSDzeaKu9vvg" }, "id": "uSDzeaKu9vvg", "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Experimental tracking setup (W&B)\n", "# Run this cell first, then paste your W&B API key when prompted\n", "# Get your free API key at: https://wandb.ai/authorize\n", "!pip install wandb -q\n", "import wandb\n", "wandb.login()" ], "metadata": { "id": "y8iDa_25l7c0" }, "id": "y8iDa_25l7c0", "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "e2bJ5jmUxz1t" }, "execution_count": null, "outputs": [], "source": [ "from trl import GRPOTrainer, GRPOConfig\n", "\n", "cfg = GRPOConfig(\n", " output_dir='./grpo_out',\n", " learning_rate=1e-6, # ← Slower learning (was 5e-6, too aggressive)\n", " per_device_train_batch_size=1, # was 2, reduces OOM risk on T4\n", " gradient_accumulation_steps=8, # was 4, keeps effective batch size the same\n", " num_generations=8, # ← More variety per step (was 4, too few)\n", " num_train_epochs=3, # ← Train 
# Score the trained policy on the same scenarios used for the baseline.
print('Evaluating trained model...')
trained = evaluate(trainer.model, tokenizer, ds, n_per_task=N_EVAL)
for t, rs in trained.items():
    print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')

# Compare
print('\n Baseline → Trained')
for t in ['easy', 'medium', 'hard']:
    b = sum(baseline[t]) / max(len(baseline[t]), 1)
    tr = sum(trained[t]) / max(len(trained[t]), 1)
    # A relative change is undefined against a zero baseline (common when the
    # untrained model never emits valid JSON); report 'n/a' rather than a
    # meaningless multi-million percent figure.
    change = f'{(tr - b) / b * 100:+.1f}%' if b >= 1e-6 else 'n/a'
    print(f' {t:<7} {b:.3f} → {tr:.3f} ({change})')
Plots + save" ], "id": "FuVJHC5Dxz1t" }, { "cell_type": "code", "metadata": { "id": "yZaNvORKxz1t" }, "execution_count": null, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "tasks = ['easy','medium','hard']\n", "b_avg = [np.mean(baseline[t]) for t in tasks]\n", "t_avg = [np.mean(trained[t]) for t in tasks]\n", "\n", "log = trainer.state.log_history\n", "train_steps = [e.get('step') for e in log if 'reward' in e]\n", "train_rewards = [e.get('reward') for e in log if 'reward' in e]\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "# Bar: baseline vs trained\n", "x = np.arange(len(tasks)); w = 0.35\n", "axes[0].bar(x - w/2, b_avg, w, label='Baseline (Qwen2.5-0.5B, untrained)', color='#e74c3c', alpha=0.85)\n", "axes[0].bar(x + w/2, t_avg, w, label='Trained (GRPO)', color='#2ecc71', alpha=0.85)\n", "axes[0].set_xticks(x); axes[0].set_xticklabels([t.capitalize() for t in tasks])\n", "axes[0].set_ylabel('Mean reward'); axes[0].set_title('Baseline vs Trained — per task')\n", "axes[0].legend(); axes[0].grid(axis='y', alpha=0.3); axes[0].set_ylim(0, 1)\n", "for i, (b, tr) in enumerate(zip(b_avg, t_avg)):\n", " axes[0].text(i - w/2, b + 0.02, f'{b:.3f}', ha='center', fontsize=9)\n", " axes[0].text(i + w/2, tr + 0.02, f'{tr:.3f}', ha='center', fontsize=9)\n", "\n", "# Line: training reward over steps\n", "if train_rewards:\n", " axes[1].plot(train_steps, train_rewards, marker='o', color='#3498db', linewidth=2, markersize=4)\n", " axes[1].set_xlabel('Training step'); axes[1].set_ylabel('Mean reward (batch)')\n", " axes[1].set_title('GRPO — reward over training steps')\n", " axes[1].grid(alpha=0.3)\n", "else:\n", " axes[1].text(0.5, 0.5, 'No training log captured', ha='center', va='center', transform=axes[1].transAxes)\n", "\n", "plt.tight_layout()\n", "plt.savefig(f'{OUT_DIR}/training_results.png', dpi=150, bbox_inches='tight')\n", "plt.show()\n", "\n", "with open(f'{OUT_DIR}/results.json', 'w') as f:\n", " 
json.dump({\n", " 'baseline': baseline,\n", " 'trained': trained,\n", " 'training_log': [{'step': s, 'reward': r} for s, r in zip(train_steps, train_rewards)],\n", " 'config': {'model': MODEL_NAME, 'n_per_task': N_PER_TASK, 'num_generations': cfg.num_generations,\n", " 'epochs': cfg.num_train_epochs, 'lr': cfg.learning_rate}\n", " }, f, indent=2, default=str)\n", "print(f'✓ Saved {OUT_DIR}/training_results.png and {OUT_DIR}/results.json')" ], "id": "yZaNvORKxz1t" }, { "cell_type": "markdown", "metadata": { "id": "RAUL-h3zxz1t" }, "source": [ "## 12. Download artifacts to your laptop" ], "id": "RAUL-h3zxz1t" }, { "cell_type": "code", "metadata": { "id": "CMbTmTZLxz1t" }, "execution_count": null, "outputs": [], "source": [ "from google.colab import files\n", "files.download(f'{OUT_DIR}/training_results.png')\n", "files.download(f'{OUT_DIR}/results.json')" ], "id": "CMbTmTZLxz1t" }, { "cell_type": "markdown", "metadata": { "id": "w2wrkkAAxz1t" }, "source": [ "---\n", "\n", "**After download:**\n", "1. Drop `training_results.png` into your project's `training_logs/` folder\n", "2. Embed it in your README under a 'Training Results' section\n", "3. Commit & push to your HF Space\n", "." ], "id": "w2wrkkAAxz1t" } ] }