DevanshuDon committed on
Commit
a81cec4
·
verified ·
1 Parent(s): a5018d2

Upload train_colab.ipynb

Files changed (1)
  1. train_colab.ipynb +663 -527
train_colab.ipynb CHANGED
@@ -1,530 +1,666 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 5,
4
- "metadata": {
5
- "kernelspec": {
6
- "display_name": "Python 3",
7
- "language": "python",
8
- "name": "python3"
9
  },
10
- "language_info": {
11
- "name": "python",
12
- "version": "3.10"
13
- },
14
- "accelerator": "GPU",
15
- "colab": {
16
- "provenance": [],
17
- "gpuType": "T4"
18
- }
19
- },
20
- "cells": [
21
- {
22
- "cell_type": "markdown",
23
- "metadata": {},
24
- "source": [
25
- "# ExecAssist \u2014 GRPO Training on Colab T4\n",
26
- "\n",
27
- "Trains Qwen2.5-0.5B-Instruct with TRL GRPO on the deployed ExecAssist environment.\n",
28
- "\n",
29
- "**Runtime:** Set runtime \u2192 T4 GPU. Total run ~45 min.\n",
30
- "\n",
31
- "**Outputs:** `training_results.png` + `results.json` for your README."
32
- ]
33
- },
34
- {
35
- "cell_type": "markdown",
36
- "metadata": {},
37
- "source": [
38
- "## 1. Install + GPU check"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "metadata": {},
44
- "execution_count": null,
45
- "outputs": [],
46
- "source": [
47
- "!pip install -q -U torch transformers trl datasets accelerate matplotlib"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "metadata": {},
53
- "execution_count": null,
54
- "outputs": [],
55
- "source": [
56
- "import torch\n",
57
- "assert torch.cuda.is_available(), 'No GPU detected \u2014 set Runtime > Change runtime type > T4 GPU'\n",
58
- "print('GPU:', torch.cuda.get_device_name(0))\n",
59
- "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')"
60
- ]
61
- },
62
- {
63
- "cell_type": "markdown",
64
- "metadata": {},
65
- "source": [
66
- "## 2. Config"
67
- ]
68
- },
69
- {
70
- "cell_type": "code",
71
- "metadata": {},
72
- "execution_count": null,
73
- "outputs": [],
74
- "source": [
75
- "import os, json, requests, copy\n",
76
- "from datetime import datetime\n",
77
- "\n",
78
- "ENV_URL = 'https://devanshudon-exec-assist.hf.space'\n",
79
- "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
80
- "N_PER_TASK = 30 # 30 each = 90 scenarios total\n",
81
- "N_EVAL = 10 # baseline/trained eval samples per task\n",
82
- "OUT_DIR = 'training_logs'\n",
83
- "os.makedirs(OUT_DIR, exist_ok=True)"
84
- ]
85
- },
86
- {
87
- "cell_type": "markdown",
88
- "metadata": {},
89
- "source": [
90
- "## 3. Collect scenarios from your HF Space\n",
91
- "\n",
92
- "Each `/reset` returns a fresh scenario. We pull 90 of them once and reuse them for both training and evaluation \u2014 this keeps reward scoring deterministic."
93
- ]
94
- },
95
- {
96
- "cell_type": "code",
97
- "metadata": {},
98
- "execution_count": null,
99
- "outputs": [],
100
- "source": [
101
- "def reset_env(task):\n",
102
- " r = requests.post(f'{ENV_URL}/reset', params={'task': task}, timeout=60)\n",
103
- " r.raise_for_status()\n",
104
- " return r.json()\n",
105
- "\n",
106
- "scenarios = []\n",
107
- "for task in ['easy', 'medium', 'hard']:\n",
108
- " for i in range(N_PER_TASK):\n",
109
- " try:\n",
110
- " data = reset_env(task)\n",
111
- " scenarios.append({'task': task, 'observation': data.get('observation', {})})\n",
112
- " except Exception as e:\n",
113
- " print(f' ! {task}#{i}: {e}')\n",
114
- " print(f' \u2713 {task}: {sum(1 for s in scenarios if s[\"task\"]==task)} collected')\n",
115
- "\n",
116
- "print(f'\\nTotal scenarios: {len(scenarios)}')"
117
- ]
118
- },
119
- {
120
- "cell_type": "markdown",
121
- "metadata": {},
122
- "source": [
123
- "## 4. Reward scoring (heuristic, deterministic, no API calls)\n",
124
- "\n",
125
- "Mirrors the env's weightings \u2014 Easy 50/50, Medium 30/40/30, Hard 34/33/33 \u2014 without the LLM judge so training stays fast."
126
- ]
127
- },
128
- {
129
- "cell_type": "code",
130
- "metadata": {},
131
- "execution_count": null,
132
- "outputs": [],
133
- "source": [
134
- "def score_email(email):\n",
135
- " if not email or len(email.strip()) < 20: return 0.0\n",
136
- " e, score = email.lower(), 0.5\n",
137
- " if any(w in e for w in ['thank','appreciate','please','regards','sincerely','best']): score += 0.15\n",
138
- " if any(e.lstrip().startswith(g) for g in ['hi ','hello ','dear ','good morning','good afternoon']): score += 0.10\n",
139
- " if any(c in e for c in ['regards','sincerely','best,','thanks,','cheers']): score += 0.10\n",
140
- " wc = len(email.split())\n",
141
- " if 25 <= wc <= 200: score += 0.15\n",
142
- " elif wc > 250: score -= 0.10\n",
143
- " elif wc < 15: score -= 0.20\n",
144
- " return max(0.0, min(1.0, score))\n",
145
- "\n",
146
- "def score_scheduling(meeting, obs):\n",
147
- " if not isinstance(meeting, dict): return 0.0\n",
148
- " has_s = 'start_time' in meeting; has_e = 'end_time' in meeting\n",
149
- " has_p = isinstance(meeting.get('participants'), list) and len(meeting.get('participants',[])) >= 2\n",
150
- " has_subj = bool(meeting.get('subject'))\n",
151
- " no_conflict = in_hours = good_dur = False\n",
152
- " if has_s and has_e:\n",
153
- " try:\n",
154
- " ns = datetime.fromisoformat(meeting['start_time'])\n",
155
- " ne = datetime.fromisoformat(meeting['end_time'])\n",
156
- " in_hours = 9 <= ns.hour < 17\n",
157
- " d = (ne - ns).total_seconds()/60\n",
158
- " good_dur = 15 <= d <= 120\n",
159
- " no_conflict = True\n",
160
- " for m in obs.get('calendar', {}).get('existing_meetings', []):\n",
161
- " try:\n",
162
- " es = datetime.fromisoformat(m['start_time'])\n",
163
- " ee = datetime.fromisoformat(m['end_time'])\n",
164
- " if not (ne <= es or ns >= ee):\n",
165
- " no_conflict = False; break\n",
166
- " except: pass\n",
167
- " except: pass\n",
168
- " checks = [has_s, has_e, has_p, has_subj, no_conflict, in_hours, good_dur]\n",
169
- " return sum(checks) / len(checks)\n",
170
- "\n",
171
- "def score_conflict(action, obs):\n",
172
- " cal_action = action.get('calendar_action','')\n",
173
- " meeting = action.get('meeting_details') or {}\n",
174
- " email = (action.get('email_reply') or '').lower()\n",
175
- " score = 0.5\n",
176
- " if any(w in email for w in ['conflict','unavailable','alternative','instead','however','unfortunately']): score += 0.15\n",
177
- " alts = meeting.get('proposed_alternatives') if isinstance(meeting, dict) else None\n",
178
- " if alts and isinstance(alts, list) and len(alts) >= 2: score += 0.25\n",
179
- " if cal_action == 'propose_alternatives': score += 0.10\n",
180
- " return max(0.0, min(1.0, score))\n",
181
- "\n",
182
- "def penalties(action):\n",
183
- " p = 0.0\n",
184
- " email = action.get('email_reply','') or ''\n",
185
- " if not email or len(email.strip()) < 20: p -= 0.30\n",
186
- " if len(email) > 1500: p -= 0.15\n",
187
- " if not action.get('meeting_details'): p -= 0.40\n",
188
- " return p\n",
189
- "\n",
190
- "def compose_reward(action, obs, task):\n",
191
- " if not isinstance(action, dict): return 0.0\n",
192
- " e = score_email(action.get('email_reply',''))\n",
193
- " s = score_scheduling(action.get('meeting_details'), obs)\n",
194
- " c = score_conflict(action, obs)\n",
195
- " p = penalties(action)\n",
196
- " if task == 'easy': base = 0.50*e + 0.50*s\n",
197
- " elif task == 'medium': base = 0.30*e + 0.40*c + 0.30*s\n",
198
- " else: base = 0.34*e + 0.33*s + 0.33*c\n",
199
- " return max(0.0, min(1.0, base + p))\n",
200
- "\n",
201
- "# Sanity check\n",
202
- "demo = {'email_reply':'Hi Sarah, Thank you for reaching out. Monday at 2pm works well. Best regards, Alex Chen',\n",
203
- " 'calendar_action':'book',\n",
204
- " 'meeting_details':{'participants':['s@x.com','a@x.com'],\n",
205
- " 'start_time':'2026-04-28T14:00:00',\n",
206
- " 'end_time':'2026-04-28T14:30:00',\n",
207
- " 'subject':'Test'}}\n",
208
- "print('Demo reward (easy):', compose_reward(demo, scenarios[0]['observation'], 'easy'))"
209
- ]
210
- },
211
- {
212
- "cell_type": "markdown",
213
- "metadata": {},
214
- "source": [
215
- "## 5. Build prompts + HF Dataset"
216
- ]
217
- },
218
- {
219
- "cell_type": "code",
220
- "metadata": {},
221
- "execution_count": null,
222
- "outputs": [],
223
- "source": [
224
- "def build_prompt(obs):\n",
225
- " emails = obs.get('emails', [])\n",
226
- " cal = obs.get('calendar', {})\n",
227
- " em_str = ''\n",
228
- " for e in emails:\n",
229
- " em_str += f\"From: {e.get('sender','?')}\\nSubject: {e.get('subject','?')}\\nPriority: {e.get('priority','normal')}\\n{e.get('body','')}\\n---\\n\"\n",
230
- " meetings = cal.get('existing_meetings', [])\n",
231
- " cal_str = 'Existing meetings:\\n' + ('\\n'.join(f\" - {m.get('subject','?')}: {m.get('start_time','?')} -> {m.get('end_time','?')}\" for m in meetings) if meetings else ' (none)')\n",
232
- " exec_name = cal.get('executive_name','Alex Chen')\n",
233
- " return f\"\"\"You are {exec_name}'s executive assistant. Reply to incoming email and book the meeting.\n",
234
- "\n",
235
- "EMAILS:\n",
236
- "{em_str}\n",
237
- "\n",
238
- "{cal_str}\n",
239
- "\n",
240
- "Working hours: 9-17. Duration: 15-120 min. Avoid double-booking.\n",
241
- "\n",
242
- "Respond with ONLY a JSON object, no commentary:\n",
243
- "{{\\\"email_reply\\\":\\\"...\\\",\\\"calendar_action\\\":\\\"book\\\",\\\"meeting_details\\\":{{\\\"participants\\\":[\\\"a@x.com\\\",\\\"b@x.com\\\"],\\\"start_time\\\":\\\"2026-04-28T14:00:00\\\",\\\"end_time\\\":\\\"2026-04-28T14:30:00\\\",\\\"subject\\\":\\\"...\\\"}}}}\"\"\"\n",
244
- "\n",
245
- "from datasets import Dataset\n",
246
- "ds = Dataset.from_dict({\n",
247
- " 'prompt': [build_prompt(s['observation']) for s in scenarios],\n",
248
- " 'task': [s['task'] for s in scenarios],\n",
249
- " 'scenario': [s['observation'] for s in scenarios],\n",
250
- "})\n",
251
- "print(f'Dataset: {len(ds)} rows')\n",
252
- "print('---\\nSample prompt (truncated):\\n', ds[0]['prompt'][:400], '...')"
253
- ]
254
- },
255
- {
256
- "cell_type": "markdown",
257
- "metadata": {},
258
- "source": [
259
- "## 6. JSON parser + GRPO reward function\n",
260
- "\n",
261
- "GRPOTrainer passes extra dataset columns (`task`, `scenario`) as kwargs to the reward function."
262
- ]
263
- },
264
- {
265
- "cell_type": "code",
266
- "metadata": {},
267
- "execution_count": null,
268
- "outputs": [],
269
- "source": [
270
- "def parse_action(text):\n",
271
- " try:\n",
272
- " s = text.find('{')\n",
273
- " if s == -1: return {}\n",
274
- " depth = 0\n",
275
- " for i, ch in enumerate(text[s:], start=s):\n",
276
- " if ch == '{': depth += 1\n",
277
- " elif ch == '}':\n",
278
- " depth -= 1\n",
279
- " if depth == 0:\n",
280
- " return json.loads(text[s:i+1])\n",
281
- " except Exception:\n",
282
- " return {}\n",
283
- " return {}\n",
284
- "\n",
285
- "def reward_function(completions, scenario, task, **kwargs):\n",
286
- " rewards = []\n",
287
- " for comp, scen, t in zip(completions, scenario, task):\n",
288
- " text = comp[-1]['content'] if isinstance(comp, list) else comp\n",
289
- " action = parse_action(text)\n",
290
- " try:\n",
291
- " r = compose_reward(action, scen, t)\n",
292
- " except Exception:\n",
293
- " r = 0.0\n",
294
- " rewards.append(float(r))\n",
295
- " return rewards\n",
296
- "\n",
297
- "# Smoke test\n",
298
- "demo_comp = json.dumps(demo)\n",
299
- "print('Reward fn output:', reward_function([demo_comp], [scenarios[0]['observation']], ['easy']))"
300
- ]
301
- },
302
- {
303
- "cell_type": "markdown",
304
- "metadata": {},
305
- "source": [
306
- "## 7. Load model + tokenizer"
307
- ]
308
- },
309
- {
310
- "cell_type": "code",
311
- "metadata": {},
312
- "execution_count": null,
313
- "outputs": [],
314
- "source": [
315
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
316
- "\n",
317
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
318
- "if tokenizer.pad_token is None:\n",
319
- " tokenizer.pad_token = tokenizer.eos_token\n",
320
- "\n",
321
- "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to('cuda')\n",
322
- "print(f'Loaded {MODEL_NAME} \u2014 {sum(p.numel() for p in model.parameters())/1e6:.0f}M params')"
323
- ]
324
- },
325
- {
326
- "cell_type": "markdown",
327
- "metadata": {},
328
- "source": [
329
- "## 8. Baseline evaluation (untrained Qwen-0.5B)"
330
- ]
331
- },
332
- {
333
- "cell_type": "code",
334
- "metadata": {},
335
- "execution_count": null,
336
- "outputs": [],
337
- "source": [
338
- "def evaluate(m, tok, dataset, n_per_task=10):\n",
339
- " m.eval()\n",
340
- " results = {'easy':[], 'medium':[], 'hard':[]}\n",
341
- " for ex in dataset:\n",
342
- " t = ex['task']\n",
343
- " if len(results[t]) >= n_per_task: continue\n",
344
- " msgs = [{'role':'user','content': ex['prompt']}]\n",
345
- " text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
346
- " ids = tok(text, return_tensors='pt', truncation=True, max_length=1024).to('cuda')\n",
347
- " with torch.no_grad():\n",
348
- " out = m.generate(**ids, max_new_tokens=350, do_sample=True, temperature=0.7,\n",
349
- " top_p=0.9, pad_token_id=tok.eos_token_id)\n",
350
- " comp = tok.decode(out[0][ids['input_ids'].shape[1]:], skip_special_tokens=True)\n",
351
- " action = parse_action(comp)\n",
352
- " r = compose_reward(action, ex['scenario'], t) if action else 0.0\n",
353
- " results[t].append(r)\n",
354
- " return results\n",
355
- "\n",
356
- "print('Evaluating baseline...')\n",
357
- "baseline = evaluate(model, tokenizer, ds, n_per_task=N_EVAL)\n",
358
- "for t, rs in baseline.items():\n",
359
- " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')"
360
- ]
361
- },
362
- {
363
- "cell_type": "markdown",
364
- "metadata": {},
365
- "source": [
366
- "## 9. Train with TRL GRPO\n",
367
- "\n",
368
- "Hyperparameters chosen for T4 + 0.5B model. Should complete in ~25-35 min."
369
- ]
370
- },
371
- {
372
- "cell_type": "code",
373
- "metadata": {},
374
- "execution_count": null,
375
- "outputs": [],
376
- "source": [
377
- "from trl import GRPOTrainer, GRPOConfig\n",
378
- "\n",
379
- "cfg = GRPOConfig(\n",
380
- " output_dir='./grpo_out',\n",
381
- " learning_rate=5e-6,\n",
382
- " per_device_train_batch_size=2,\n",
383
- " gradient_accumulation_steps=4,\n",
384
- " num_generations=4,\n",
385
- " max_prompt_length=768,\n",
386
- " max_completion_length=300,\n",
387
- " num_train_epochs=1,\n",
388
- " logging_steps=1,\n",
389
- " save_strategy='no',\n",
390
- " fp16=True,\n",
391
- " report_to='none',\n",
392
- " temperature=0.9,\n",
393
- " beta=0.04,\n",
394
- " gradient_checkpointing=True,\n",
395
- ")\n",
396
- "\n",
397
- "trainer = GRPOTrainer(\n",
398
- " model=model,\n",
399
- " args=cfg,\n",
400
- " train_dataset=ds,\n",
401
- " reward_funcs=reward_function,\n",
402
- " processing_class=tokenizer,\n",
403
- ")\n",
404
- "\n",
405
- "print('Starting GRPO training...')\n",
406
- "trainer.train()\n",
407
- "print('\u2705 Training complete')"
408
- ]
409
- },
410
- {
411
- "cell_type": "markdown",
412
- "metadata": {},
413
- "source": [
414
- "## 10. Post-training evaluation"
415
- ]
416
- },
417
- {
418
- "cell_type": "code",
419
- "metadata": {},
420
- "execution_count": null,
421
- "outputs": [],
422
- "source": [
423
- "print('Evaluating trained model...')\n",
424
- "trained = evaluate(trainer.model, tokenizer, ds, n_per_task=N_EVAL)\n",
425
- "for t, rs in trained.items():\n",
426
- " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')\n",
427
- "\n",
428
- "# Compare\n",
429
- "print('\\n Baseline \u2192 Trained')\n",
430
- "for t in ['easy','medium','hard']:\n",
431
- " b = sum(baseline[t])/max(len(baseline[t]),1)\n",
432
- " tr = sum(trained[t])/max(len(trained[t]),1)\n",
433
- " delta = (tr - b) / max(b, 1e-6) * 100\n",
434
- " print(f' {t:<7} {b:.3f} \u2192 {tr:.3f} ({delta:+.1f}%)')"
435
- ]
436
- },
437
- {
438
- "cell_type": "markdown",
439
- "metadata": {},
440
- "source": [
441
- "## 11. Plots + save"
442
- ]
443
- },
444
- {
445
- "cell_type": "code",
446
- "metadata": {},
447
- "execution_count": null,
448
- "outputs": [],
449
- "source": [
450
- "import matplotlib.pyplot as plt\n",
451
- "import numpy as np\n",
452
- "\n",
453
- "tasks = ['easy','medium','hard']\n",
454
- "b_avg = [np.mean(baseline[t]) for t in tasks]\n",
455
- "t_avg = [np.mean(trained[t]) for t in tasks]\n",
456
- "\n",
457
- "log = trainer.state.log_history\n",
458
- "train_steps = [e.get('step') for e in log if 'reward' in e]\n",
459
- "train_rewards = [e.get('reward') for e in log if 'reward' in e]\n",
460
- "\n",
461
- "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
462
- "\n",
463
- "# Bar: baseline vs trained\n",
464
- "x = np.arange(len(tasks)); w = 0.35\n",
465
- "axes[0].bar(x - w/2, b_avg, w, label='Baseline (Qwen2.5-0.5B, untrained)', color='#e74c3c', alpha=0.85)\n",
466
- "axes[0].bar(x + w/2, t_avg, w, label='Trained (GRPO)', color='#2ecc71', alpha=0.85)\n",
467
- "axes[0].set_xticks(x); axes[0].set_xticklabels([t.capitalize() for t in tasks])\n",
468
- "axes[0].set_ylabel('Mean reward'); axes[0].set_title('Baseline vs Trained \u2014 per task')\n",
469
- "axes[0].legend(); axes[0].grid(axis='y', alpha=0.3); axes[0].set_ylim(0, 1)\n",
470
- "for i, (b, tr) in enumerate(zip(b_avg, t_avg)):\n",
471
- " axes[0].text(i - w/2, b + 0.02, f'{b:.3f}', ha='center', fontsize=9)\n",
472
- " axes[0].text(i + w/2, tr + 0.02, f'{tr:.3f}', ha='center', fontsize=9)\n",
473
- "\n",
474
- "# Line: training reward over steps\n",
475
- "if train_rewards:\n",
476
- " axes[1].plot(train_steps, train_rewards, marker='o', color='#3498db', linewidth=2, markersize=4)\n",
477
- " axes[1].set_xlabel('Training step'); axes[1].set_ylabel('Mean reward (batch)')\n",
478
- " axes[1].set_title('GRPO \u2014 reward over training steps')\n",
479
- " axes[1].grid(alpha=0.3)\n",
480
- "else:\n",
481
- " axes[1].text(0.5, 0.5, 'No training log captured', ha='center', va='center', transform=axes[1].transAxes)\n",
482
- "\n",
483
- "plt.tight_layout()\n",
484
- "plt.savefig(f'{OUT_DIR}/training_results.png', dpi=150, bbox_inches='tight')\n",
485
- "plt.show()\n",
486
- "\n",
487
- "with open(f'{OUT_DIR}/results.json', 'w') as f:\n",
488
- " json.dump({\n",
489
- " 'baseline': baseline,\n",
490
- " 'trained': trained,\n",
491
- " 'training_log': [{'step': s, 'reward': r} for s, r in zip(train_steps, train_rewards)],\n",
492
- " 'config': {'model': MODEL_NAME, 'n_per_task': N_PER_TASK, 'num_generations': cfg.num_generations,\n",
493
- " 'epochs': cfg.num_train_epochs, 'lr': cfg.learning_rate}\n",
494
- " }, f, indent=2, default=str)\n",
495
- "print(f'\u2713 Saved {OUT_DIR}/training_results.png and {OUT_DIR}/results.json')"
496
- ]
497
- },
498
- {
499
- "cell_type": "markdown",
500
- "metadata": {},
501
- "source": [
502
- "## 12. Download artifacts to your laptop"
503
- ]
504
- },
505
- {
506
- "cell_type": "code",
507
- "metadata": {},
508
- "execution_count": null,
509
- "outputs": [],
510
- "source": [
511
- "from google.colab import files\n",
512
- "files.download(f'{OUT_DIR}/training_results.png')\n",
513
- "files.download(f'{OUT_DIR}/results.json')"
514
- ]
515
- },
516
- {
517
- "cell_type": "markdown",
518
- "metadata": {},
519
- "source": [
520
- "---\n",
521
- "\n",
522
- "**After download:**\n",
523
- "1. Drop `training_results.png` into your project's `training_logs/` folder\n",
524
- "2. Embed it in your README under a 'Training Results' section\n",
525
- "3. Commit & push to your HF Space\n",
526
- "4. You're done \u2014 switch to writing the mini-blog (Opus session)."
527
- ]
528
- }
529
- ]
530
  }
 
1
  {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3",
7
+ "name": "python3"
8
+ },
9
+ "language_info": {
10
+ "name": "python",
11
+ "version": "3.10"
12
+ },
13
+ "accelerator": "GPU",
14
+ "colab": {
15
+ "provenance": [],
16
+ "gpuType": "T4"
17
+ }
18
  },
19
+ "cells": [
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "id": "ZXq1I39Jxz1q"
24
+ },
25
+ "source": [
26
+ "# ExecAssist \u2014 GRPO Training on Colab T4\n",
27
+ "\n",
28
+ "Trains Qwen2.5-0.5B-Instruct with TRL GRPO on the deployed ExecAssist environment.\n",
29
+ "\n",
30
+ "**Runtime:** Set runtime \u2192 T4 GPU. Total run ~45 min.\n",
31
+ "\n",
32
+ "**Outputs:** `training_results.png` + `results.json` for your README."
33
+ ],
34
+ "id": "ZXq1I39Jxz1q"
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {
39
+ "id": "p1d4Lkp_xz1r"
40
+ },
41
+ "source": [
42
+ "## 1. Install + GPU check"
43
+ ],
44
+ "id": "p1d4Lkp_xz1r"
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "metadata": {
49
+ "colab": {
50
+ "base_uri": "https://localhost:8080/"
51
+ },
52
+ "id": "fXDfPHQvxz1r",
53
+ "outputId": "afb1898b-7db8-4783-8b5d-619d9bbd600e"
54
+ },
55
+ "execution_count": null,
56
+ "outputs": [
57
+ {
58
+ "output_type": "stream",
59
+ "name": "stdout",
60
+ "text": [
61
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m52.8/52.8 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
62
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m530.7/530.7 MB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
63
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m366.1/366.1 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
64
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m169.9/169.9 MB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
65
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.5/196.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
66
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.4/60.4 MB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
67
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.3/188.3 MB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
68
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m423.1/423.1 MB\u001b[0m \u001b[31m824.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
69
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.7/10.7 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
70
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.2/90.2 MB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
71
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m81.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
72
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m214.1/214.1 MB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
73
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m47.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
74
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.5/59.5 MB\u001b[0m \u001b[31m12.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
75
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.9/200.9 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
76
+ "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m145.9/145.9 MB\u001b[0m \u001b[31m131.5 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m"
77
+ ]
78
+ }
79
+ ],
80
+ "source": [
81
+ "!pip uninstall -y torchvision torchaudio -q\n",
82
+ "!pip install -q -U torch transformers trl datasets accelerate matplotlib"
83
+ ],
84
+ "id": "fXDfPHQvxz1r"
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "metadata": {
89
+ "id": "QOqVp9nrxz1r"
90
+ },
91
+ "execution_count": null,
92
+ "outputs": [],
93
+ "source": [
94
+ "import torch\n",
95
+ "assert torch.cuda.is_available(), 'No GPU detected \u2014 set Runtime > Change runtime type > T4 GPU'\n",
96
+ "print('GPU:', torch.cuda.get_device_name(0))\n",
97
+ "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')"
98
+ ],
99
+ "id": "QOqVp9nrxz1r"
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "metadata": {
104
+ "id": "NeEF7A1mxz1s"
105
+ },
106
+ "source": [
107
+ "## 2. Config"
108
+ ],
109
+ "id": "NeEF7A1mxz1s"
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "metadata": {
114
+ "id": "YcYdPAMNxz1s"
115
+ },
116
+ "execution_count": null,
117
+ "outputs": [],
118
+ "source": [
119
+ "import os, json, requests, copy\n",
120
+ "from datetime import datetime\n",
121
+ "\n",
122
+ "ENV_URL = 'https://devanshudon-exec-assist.hf.space'\n",
123
+ "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
124
+ "N_PER_TASK = 30 # 30 each = 90 scenarios total\n",
125
+ "N_EVAL = 10 # baseline/trained eval samples per task\n",
126
+ "OUT_DIR = 'training_logs'\n",
127
+ "os.makedirs(OUT_DIR, exist_ok=True)"
128
+ ],
129
+ "id": "YcYdPAMNxz1s"
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "metadata": {
134
+ "id": "2ODWK44mxz1s"
135
+ },
136
+ "source": [
137
+ "## 3. Collect scenarios from your HF Space\n",
138
+ "\n",
139
+ "Each `/reset` returns a fresh scenario. We pull 90 of them once and reuse them for both training and evaluation \u2014 this keeps reward scoring deterministic."
140
+ ],
141
+ "id": "2ODWK44mxz1s"
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "metadata": {
146
+ "id": "_BpUrmBOxz1s"
147
+ },
148
+ "execution_count": null,
149
+ "outputs": [],
150
+ "source": [
151
+ "def reset_env(task):\n",
152
+ " r = requests.post(f'{ENV_URL}/reset', params={'task': task}, timeout=60)\n",
153
+ " r.raise_for_status()\n",
154
+ " return r.json()\n",
155
+ "\n",
156
+ "scenarios = []\n",
157
+ "for task in ['easy', 'medium', 'hard']:\n",
158
+ " for i in range(N_PER_TASK):\n",
159
+ " try:\n",
160
+ " data = reset_env(task)\n",
161
+ " scenarios.append({'task': task, 'observation': data.get('observation', {})})\n",
162
+ " except Exception as e:\n",
163
+ " print(f' ! {task}#{i}: {e}')\n",
164
+ " print(f' \u2713 {task}: {sum(1 for s in scenarios if s[\"task\"]==task)} collected')\n",
165
+ "\n",
166
+ "print(f'\\nTotal scenarios: {len(scenarios)}')"
167
+ ],
168
+ "id": "_BpUrmBOxz1s"
169
+ },
170
+ {
171
+ "cell_type": "markdown",
172
+ "metadata": {
173
+ "id": "hjRq6oNfxz1s"
174
+ },
175
+ "source": [
176
+ "## 4. Reward scoring (heuristic, deterministic, no API calls)\n",
177
+ "\n",
178
+ "Mirrors the env's weightings \u2014 Easy 50/50, Medium 30/40/30, Hard 34/33/33 \u2014 without the LLM judge so training stays fast."
179
+ ],
180
+ "id": "hjRq6oNfxz1s"
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "metadata": {
185
+ "id": "nprQokwPxz1s"
186
+ },
187
+ "execution_count": null,
188
+ "outputs": [],
189
+ "source": [
190
+ "def score_email(email):\n",
191
+ " if not email or len(email.strip()) < 20: return 0.0\n",
192
+ " e, score = email.lower(), 0.5\n",
193
+ " if any(w in e for w in ['thank','appreciate','please','regards','sincerely','best']): score += 0.15\n",
194
+ " if any(e.lstrip().startswith(g) for g in ['hi ','hello ','dear ','good morning','good afternoon']): score += 0.10\n",
195
+ " if any(c in e for c in ['regards','sincerely','best,','thanks,','cheers']): score += 0.10\n",
196
+ " wc = len(email.split())\n",
197
+ " if 25 <= wc <= 200: score += 0.15\n",
198
+ " elif wc > 250: score -= 0.10\n",
199
+ " elif wc < 15: score -= 0.20\n",
200
+ " return max(0.0, min(1.0, score))\n",
201
+ "\n",
202
+ "def score_scheduling(meeting, obs):\n",
203
+ " if not isinstance(meeting, dict): return 0.0\n",
204
+ " has_s = 'start_time' in meeting; has_e = 'end_time' in meeting\n",
205
+ " has_p = isinstance(meeting.get('participants'), list) and len(meeting.get('participants',[])) >= 2\n",
206
+ " has_subj = bool(meeting.get('subject'))\n",
207
+ " no_conflict = in_hours = good_dur = False\n",
208
+ " if has_s and has_e:\n",
209
+ " try:\n",
210
+ " ns = datetime.fromisoformat(meeting['start_time'])\n",
211
+ " ne = datetime.fromisoformat(meeting['end_time'])\n",
212
+ " in_hours = 9 <= ns.hour < 17\n",
213
+ " d = (ne - ns).total_seconds()/60\n",
214
+ " good_dur = 15 <= d <= 120\n",
215
+ " no_conflict = True\n",
216
+ " for m in obs.get('calendar', {}).get('existing_meetings', []):\n",
217
+ " try:\n",
218
+ " es = datetime.fromisoformat(m['start_time'])\n",
219
+ " ee = datetime.fromisoformat(m['end_time'])\n",
220
+ " if not (ne <= es or ns >= ee):\n",
221
+ " no_conflict = False; break\n",
222
+ " except: pass\n",
223
+ " except: pass\n",
224
+ " checks = [has_s, has_e, has_p, has_subj, no_conflict, in_hours, good_dur]\n",
225
+ " return sum(checks) / len(checks)\n",
226
+ "\n",
227
+ "def score_conflict(action, obs):\n",
228
+ " cal_action = action.get('calendar_action','')\n",
229
+ " meeting = action.get('meeting_details') or {}\n",
230
+ " email = (action.get('email_reply') or '').lower()\n",
231
+ " score = 0.5\n",
232
+ " if any(w in email for w in ['conflict','unavailable','alternative','instead','however','unfortunately']): score += 0.15\n",
233
+ " alts = meeting.get('proposed_alternatives') if isinstance(meeting, dict) else None\n",
234
+ " if alts and isinstance(alts, list) and len(alts) >= 2: score += 0.25\n",
235
+ " if cal_action == 'propose_alternatives': score += 0.10\n",
236
+ " return max(0.0, min(1.0, score))\n",
237
+ "\n",
238
+ "def penalties(action):\n",
239
+ " p = 0.0\n",
240
+ " email = action.get('email_reply','') or ''\n",
241
+ " if not email or len(email.strip()) < 20: p -= 0.30\n",
242
+ " if len(email) > 1500: p -= 0.15\n",
243
+ " if not action.get('meeting_details'): p -= 0.40\n",
244
+ " return p\n",
245
+ "\n",
246
+ "def compose_reward(action, obs, task):\n",
247
+ " if not isinstance(action, dict): return 0.0\n",
248
+ " e = score_email(action.get('email_reply',''))\n",
249
+ " s = score_scheduling(action.get('meeting_details'), obs)\n",
250
+ " c = score_conflict(action, obs)\n",
251
+ " p = penalties(action)\n",
252
+ " if task == 'easy': base = 0.50*e + 0.50*s\n",
253
+ " elif task == 'medium': base = 0.30*e + 0.40*c + 0.30*s\n",
254
+ " else: base = 0.34*e + 0.33*s + 0.33*c\n",
255
+ " return max(0.0, min(1.0, base + p))\n",
256
+ "\n",
257
+ "# Sanity check\n",
258
+ "demo = {'email_reply':'Hi Sarah, Thank you for reaching out. Monday at 2pm works well. Best regards, Alex Chen',\n",
259
+ " 'calendar_action':'book',\n",
260
+ " 'meeting_details':{'participants':['s@x.com','a@x.com'],\n",
261
+ " 'start_time':'2026-04-28T14:00:00',\n",
262
+ " 'end_time':'2026-04-28T14:30:00',\n",
263
+ " 'subject':'Test'}}\n",
264
+ "print('Demo reward (easy):', compose_reward(demo, scenarios[0]['observation'], 'easy'))"
265
+ ],
266
+ "id": "nprQokwPxz1s"
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "metadata": {
271
+ "id": "4nS43nVkxz1s"
272
+ },
273
+ "source": [
274
+ "## 5. Build prompts + HF Dataset"
275
+ ],
276
+ "id": "4nS43nVkxz1s"
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "metadata": {
281
+ "id": "rfhuu5Ebxz1s"
282
+ },
283
+ "execution_count": null,
284
+ "outputs": [],
285
+ "source": [
286
+ "def build_prompt(obs):\n",
287
+ "    emails = obs.get('emails', [])\n",
288
+ "    cal = obs.get('calendar', {})\n",
289
+ "    em_str = ''\n",
290
+ "    for e in emails:\n",
291
+ "        em_str += f\"From: {e.get('sender','?')}\\nSubject: {e.get('subject','?')}\\nPriority: {e.get('priority','normal')}\\n{e.get('body','')}\\n---\\n\"\n",
292
+ "    meetings = cal.get('existing_meetings', [])\n",
293
+ "    cal_str = 'Existing meetings:\\n' + ('\\n'.join(f\" - {m.get('subject','?')}: {m.get('start_time','?')} -> {m.get('end_time','?')}\" for m in meetings) if meetings else ' (none)')\n",
294
+ "    exec_name = cal.get('executive_name','Alex Chen')\n",
295
+ "    return f\"\"\"You are {exec_name}'s executive assistant. Reply to incoming email and book the meeting.\n",
296
+ "\n",
297
+ "EMAILS:\n",
298
+ "{em_str}\n",
299
+ "\n",
300
+ "{cal_str}\n",
301
+ "\n",
302
+ "Working hours: 9-17. Duration: 15-120 min. Avoid double-booking.\n",
303
+ "\n",
304
+ "Respond with ONLY a JSON object, no commentary:\n",
305
+ "{{\\\"email_reply\\\":\\\"...\\\",\\\"calendar_action\\\":\\\"book\\\",\\\"meeting_details\\\":{{\\\"participants\\\":[\\\"a@x.com\\\",\\\"b@x.com\\\"],\\\"start_time\\\":\\\"2026-04-28T14:00:00\\\",\\\"end_time\\\":\\\"2026-04-28T14:30:00\\\",\\\"subject\\\":\\\"...\\\"}}}}\"\"\"\n",
306
+ "\n",
307
+ "from datasets import Dataset\n",
308
+ "ds = Dataset.from_dict({\n",
309
+ " 'prompt': [build_prompt(s['observation']) for s in scenarios],\n",
310
+ " 'task': [s['task'] for s in scenarios],\n",
311
+ " 'scenario': [s['observation'] for s in scenarios],\n",
312
+ "})\n",
313
+ "print(f'Dataset: {len(ds)} rows')\n",
314
+ "print('---\\nSample prompt (truncated):\\n', ds[0]['prompt'][:400], '...')"
315
+ ],
316
+ "id": "rfhuu5Ebxz1s"
317
+ },
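The `prompt` column above holds plain strings, whereas `evaluate` in section 8 wraps each prompt in a chat message before generation. One way to keep training and evaluation inputs in the same shape is to store prompts in conversational form. This is a sketch under the assumption that the installed TRL version accepts a list of role/content messages in the `prompt` column and applies the chat template itself; `to_conversational` is a hypothetical helper, not part of the notebook.

```python
def to_conversational(prompts):
    """Wrap each plain-text prompt in a single-turn chat message list."""
    return [[{'role': 'user', 'content': p}] for p in prompts]

example = to_conversational(['Reply to Sarah and book Monday at 2pm.'])
print(example[0])
```

If this form were adopted, `evaluate` would need to pass `ex['prompt']` directly as the message list instead of wrapping it again. `reward_function` already handles completions that arrive as message lists.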
318
+ {
319
+ "cell_type": "markdown",
320
+ "metadata": {
321
+ "id": "9ztlxmy3xz1t"
322
+ },
323
+ "source": [
324
+ "## 6. JSON parser + GRPO reward function\n",
325
+ "\n",
326
+ "GRPOTrainer passes the extra dataset columns (`task`, `scenario`) to the reward function as keyword arguments, each a list aligned with `completions`."
327
+ ],
328
+ "id": "9ztlxmy3xz1t"
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "metadata": {
333
+ "id": "OxPpcvbzxz1t"
334
+ },
335
+ "execution_count": null,
336
+ "outputs": [],
337
+ "source": [
338
+ "def parse_action(text):\n",
339
+ "    try:\n",
340
+ "        s = text.find('{')\n",
341
+ "        if s == -1: return {}\n",
342
+ "        depth = 0\n",
343
+ "        for i, ch in enumerate(text[s:], start=s):\n",
344
+ "            if ch == '{': depth += 1\n",
345
+ "            elif ch == '}':\n",
346
+ "                depth -= 1\n",
347
+ "                if depth == 0:\n",
348
+ "                    return json.loads(text[s:i+1])\n",
349
+ "    except Exception:\n",
350
+ "        return {}\n",
351
+ "    return {}\n",
352
+ "\n",
353
+ "def reward_function(completions, scenario, task, **kwargs):\n",
354
+ "    rewards = []\n",
355
+ "    for comp, scen, t in zip(completions, scenario, task):\n",
356
+ "        text = comp[-1]['content'] if isinstance(comp, list) else comp\n",
357
+ "        action = parse_action(text)\n",
358
+ "        try:\n",
359
+ "            r = compose_reward(action, scen, t)\n",
360
+ "        except Exception:\n",
361
+ "            r = 0.0\n",
362
+ "        rewards.append(float(r))\n",
363
+ "    return rewards\n",
364
+ "\n",
365
+ "# Smoke test\n",
366
+ "demo_comp = json.dumps(demo)\n",
367
+ "print('Reward fn output:', reward_function([demo_comp], [scenarios[0]['observation']], ['easy']))"
368
+ ],
369
+ "id": "OxPpcvbzxz1t"
370
+ },
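`parse_action` counts braces character by character, so an unbalanced brace inside a string value (for example `"ok :-}"`) ends the scan early and the completion scores 0. A possible alternative, sketched here with the standard library's `json.JSONDecoder.raw_decode`, lets the JSON parser find the end of the object. `parse_first_json_object` is a hypothetical name, and unlike `parse_action` it keeps trying later `{` positions after a failed parse.

```python
import json

def parse_first_json_object(text):
    """Return the first JSON object found in text, or {} if none parses."""
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at index `start`
            # and ignores any trailing text after it.
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        start = text.find('{', start + 1)
    return {}

reply = 'Sure! {"email_reply": "ok :-}", "calendar_action": "book"} Done.'
print(parse_first_json_object(reply))
```

Because the decoder understands string quoting, braces inside `email_reply` no longer affect where the object ends.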
371
+ {
372
+ "cell_type": "markdown",
373
+ "metadata": {
374
+ "id": "fIDady-Oxz1t"
375
+ },
376
+ "source": [
377
+ "## 7. Load model + tokenizer"
378
+ ],
379
+ "id": "fIDady-Oxz1t"
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "metadata": {
384
+ "id": "U-GrWh3uxz1t"
385
+ },
386
+ "execution_count": null,
387
+ "outputs": [],
388
+ "source": [
389
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
390
+ "\n",
391
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
392
+ "if tokenizer.pad_token is None:\n",
393
+ " tokenizer.pad_token = tokenizer.eos_token\n",
394
+ "\n",
395
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to('cuda')\n",
396
+ "print(f'Loaded {MODEL_NAME} - {sum(p.numel() for p in model.parameters())/1e6:.0f}M params')"
397
+ ],
398
+ "id": "U-GrWh3uxz1t"
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "metadata": {
403
+ "id": "yWlKwQnaxz1t"
404
+ },
405
+ "source": [
406
+ "## 8. Baseline evaluation (Qwen2.5-0.5B-Instruct before GRPO)"
407
+ ],
408
+ "id": "yWlKwQnaxz1t"
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "metadata": {
413
+ "id": "8aM2JB7pxz1t"
414
+ },
415
+ "execution_count": null,
416
+ "outputs": [],
417
+ "source": [
418
+ "def evaluate(m, tok, dataset, n_per_task=10):\n",
419
+ "    m.eval()\n",
420
+ "    results = {'easy':[], 'medium':[], 'hard':[]}\n",
421
+ "    for ex in dataset:\n",
422
+ "        t = ex['task']\n",
423
+ "        if len(results[t]) >= n_per_task: continue\n",
424
+ "        msgs = [{'role':'user','content': ex['prompt']}]\n",
425
+ "        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
426
+ "        ids = tok(text, return_tensors='pt', truncation=True, max_length=1024).to('cuda')\n",
427
+ "        with torch.no_grad():\n",
428
+ "            out = m.generate(**ids, max_new_tokens=350, do_sample=True, temperature=0.7,\n",
429
+ "                             top_p=0.9, pad_token_id=tok.eos_token_id)\n",
430
+ "        comp = tok.decode(out[0][ids['input_ids'].shape[1]:], skip_special_tokens=True)\n",
431
+ "        action = parse_action(comp)\n",
432
+ "        r = compose_reward(action, ex['scenario'], t) if action else 0.0\n",
433
+ "        results[t].append(r)\n",
434
+ "    return results\n",
435
+ "\n",
436
+ "print('Evaluating baseline...')\n",
437
+ "baseline = evaluate(model, tokenizer, ds, n_per_task=N_EVAL)\n",
438
+ "for t, rs in baseline.items():\n",
439
+ " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')"
440
+ ],
441
+ "id": "8aM2JB7pxz1t"
442
+ },
443
+ {
444
+ "cell_type": "markdown",
445
+ "metadata": {
446
+ "id": "QbsM0OX1xz1t"
447
+ },
448
+ "source": [
449
+ "## 9. Train with TRL GRPO\n",
450
+ "\n",
451
+ "Hyperparameters are chosen for a T4 GPU and a 0.5B-parameter model. Training should complete in roughly 25-35 min."
452
+ ],
453
+ "id": "QbsM0OX1xz1t"
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "source": [
458
+ "# Cast model to float32 for stable GRPO training\n",
459
+ "model = model.float()\n",
460
+ "print(\"Model dtype:\", next(model.parameters()).dtype) # should print: torch.float32"
461
+ ],
462
+ "metadata": {
463
+ "id": "aWkbPpIf3SBN"
464
+ },
465
+ "id": "aWkbPpIf3SBN",
466
+ "execution_count": null,
467
+ "outputs": []
468
+ },
469
+ {
470
+ "cell_type": "code",
471
+ "source": [
472
+ "# Reload clean model (undo bad training)\n",
473
+ "from transformers import AutoModelForCausalLM\n",
474
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).float()\n",
475
+ "print(\"✅ Clean model reloaded\")"
476
+ ],
477
+ "metadata": {
478
+ "id": "uSDzeaKu9vvg"
479
+ },
480
+ "id": "uSDzeaKu9vvg",
481
+ "execution_count": null,
482
+ "outputs": []
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "metadata": {
487
+ "id": "e2bJ5jmUxz1t"
488
+ },
489
+ "execution_count": null,
490
+ "outputs": [],
491
+ "source": [
492
+ "from trl import GRPOTrainer, GRPOConfig\n",
493
+ "\n",
494
+ "cfg = GRPOConfig(\n",
495
+ " output_dir='./grpo_out',\n",
496
+ " learning_rate=1e-6, # ← Slower learning (was 5e-6, too aggressive)\n",
497
+ " per_device_train_batch_size=2,\n",
498
+ " gradient_accumulation_steps=4,\n",
499
+ " num_generations=8, # ← More variety per step (was 4, too few)\n",
500
+ " num_train_epochs=3, # ← Train longer (was 1, too short)\n",
501
+ " logging_steps=1,\n",
502
+ " save_strategy='no',\n",
503
+ " fp16=False,\n",
504
+ " bf16=False,\n",
505
+ " beta=0.1, # ← KL penalty - stops the collapse (was missing)\n",
506
+ " report_to='wandb', # experiment tracking enabled (run `wandb login` first)\n",
507
+ "\n",
508
+ " gradient_checkpointing=True,\n",
509
+ ")\n",
510
+ "\n",
511
+ "trainer = GRPOTrainer(\n",
512
+ " model=model,\n",
513
+ " args=cfg,\n",
514
+ " train_dataset=ds,\n",
515
+ " reward_funcs=reward_function,\n",
516
+ " processing_class=tokenizer,\n",
517
+ ")\n",
518
+ "\n",
519
+ "print(\"Starting GRPO training (fixed)...\")\n",
520
+ "trainer.train()\n",
521
+ "print(\"✅ Training complete\")"
522
+ ],
523
+ "id": "e2bJ5jmUxz1t"
524
+ },
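The `num_generations=8` setting matters because GRPO scores each completion relative to the other completions sampled for the same prompt. The sketch below illustrates that group-relative normalization in plain Python. It is an illustration of the general idea, not TRL's internal code: `group_advantages` and `eps` are hypothetical names, and the population standard deviation is an assumption (implementations differ on the exact normalization).

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """Center and scale the rewards of one prompt's group of completions."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    return [(r - mu) / (sigma + eps) for r in rewards]

# Eight completions for one prompt, scored by the reward function.
varied = [0.10, 0.25, 0.40, 0.40, 0.55, 0.70, 0.85, 0.95]
print([round(a, 2) for a in group_advantages(varied)])

# If every completion earns the same reward, all advantages are zero
# and the group contributes no learning signal.
print(group_advantages([0.5] * 8))
```

A larger group makes it more likely that rewards differ within the group, which is what gives the policy gradient something to work with.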
525
+ {
526
+ "cell_type": "markdown",
527
+ "metadata": {
528
+ "id": "Zllt81y0xz1t"
529
+ },
530
+ "source": [
531
+ "## 10. Post-training evaluation"
532
+ ],
533
+ "id": "Zllt81y0xz1t"
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "metadata": {
538
+ "id": "qikGbS98xz1t"
539
+ },
540
+ "execution_count": null,
541
+ "outputs": [],
542
+ "source": [
543
+ "print('Evaluating trained model...')\n",
544
+ "trained = evaluate(trainer.model, tokenizer, ds, n_per_task=N_EVAL)\n",
545
+ "for t, rs in trained.items():\n",
546
+ " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')\n",
547
+ "\n",
548
+ "# Compare\n",
549
+ "print('\\n Baseline → Trained')\n",
550
+ "for t in ['easy','medium','hard']:\n",
551
+ " b = sum(baseline[t])/max(len(baseline[t]),1)\n",
552
+ " tr = sum(trained[t])/max(len(trained[t]),1)\n",
553
+ " delta = (tr - b) / max(b, 1e-6) * 100\n",
554
+ " print(f' {t:<7} {b:.3f} → {tr:.3f} ({delta:+.1f}%)')"
555
+ ],
556
+ "id": "qikGbS98xz1t"
557
+ },
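The comparison above divides by `max(b, 1e-6)`, so a baseline average of exactly 0 produces an extremely large percentage. A hedged alternative is to report the absolute change alongside the relative one and leave the percentage undefined when the baseline is zero. `summarize_change` is a hypothetical helper written for this illustration.

```python
def summarize_change(baseline_scores, trained_scores):
    """Return mean rewards, the absolute delta, and a guarded percent delta."""
    b = sum(baseline_scores) / max(len(baseline_scores), 1)
    t = sum(trained_scores) / max(len(trained_scores), 1)
    # A relative change is not meaningful when the baseline mean is zero.
    pct = None if b == 0 else (t - b) / b * 100
    return {'baseline': b, 'trained': t, 'abs_delta': t - b, 'pct_delta': pct}

print(summarize_change([0.2, 0.4], [0.3, 0.6]))
print(summarize_change([0.0, 0.0], [0.1, 0.3]))
```

With only a handful of sampled evaluations per task, the absolute delta is usually the safer number to quote.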
558
+ {
559
+ "cell_type": "markdown",
560
+ "metadata": {
561
+ "id": "FuVJHC5Dxz1t"
562
+ },
563
+ "source": [
564
+ "## 11. Plots + save"
565
+ ],
566
+ "id": "FuVJHC5Dxz1t"
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "metadata": {
571
+ "id": "yZaNvORKxz1t"
572
+ },
573
+ "execution_count": null,
574
+ "outputs": [],
575
+ "source": [
576
+ "import matplotlib.pyplot as plt\n",
577
+ "import numpy as np\n",
578
+ "\n",
579
+ "tasks = ['easy','medium','hard']\n",
580
+ "b_avg = [np.mean(baseline[t]) for t in tasks]\n",
581
+ "t_avg = [np.mean(trained[t]) for t in tasks]\n",
582
+ "\n",
583
+ "log = trainer.state.log_history\n",
584
+ "train_steps = [e.get('step') for e in log if 'reward' in e]\n",
585
+ "train_rewards = [e.get('reward') for e in log if 'reward' in e]\n",
586
+ "\n",
587
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
588
+ "\n",
589
+ "# Bar: baseline vs trained\n",
590
+ "x = np.arange(len(tasks)); w = 0.35\n",
591
+ "axes[0].bar(x - w/2, b_avg, w, label='Baseline (Qwen2.5-0.5B, untrained)', color='#e74c3c', alpha=0.85)\n",
592
+ "axes[0].bar(x + w/2, t_avg, w, label='Trained (GRPO)', color='#2ecc71', alpha=0.85)\n",
593
+ "axes[0].set_xticks(x); axes[0].set_xticklabels([t.capitalize() for t in tasks])\n",
594
+ "axes[0].set_ylabel('Mean reward'); axes[0].set_title('Baseline vs Trained - per task')\n",
595
+ "axes[0].legend(); axes[0].grid(axis='y', alpha=0.3); axes[0].set_ylim(0, 1)\n",
596
+ "for i, (b, tr) in enumerate(zip(b_avg, t_avg)):\n",
597
+ " axes[0].text(i - w/2, b + 0.02, f'{b:.3f}', ha='center', fontsize=9)\n",
598
+ " axes[0].text(i + w/2, tr + 0.02, f'{tr:.3f}', ha='center', fontsize=9)\n",
599
+ "\n",
600
+ "# Line: training reward over steps\n",
601
+ "if train_rewards:\n",
602
+ " axes[1].plot(train_steps, train_rewards, marker='o', color='#3498db', linewidth=2, markersize=4)\n",
603
+ " axes[1].set_xlabel('Training step'); axes[1].set_ylabel('Mean reward (batch)')\n",
604
+ " axes[1].set_title('GRPO: reward over training steps')\n",
605
+ " axes[1].grid(alpha=0.3)\n",
606
+ "else:\n",
607
+ " axes[1].text(0.5, 0.5, 'No training log captured', ha='center', va='center', transform=axes[1].transAxes)\n",
608
+ "\n",
609
+ "plt.tight_layout()\n",
610
+ "plt.savefig(f'{OUT_DIR}/training_results.png', dpi=150, bbox_inches='tight')\n",
611
+ "plt.show()\n",
612
+ "\n",
613
+ "with open(f'{OUT_DIR}/results.json', 'w') as f:\n",
614
+ " json.dump({\n",
615
+ " 'baseline': baseline,\n",
616
+ " 'trained': trained,\n",
617
+ " 'training_log': [{'step': s, 'reward': r} for s, r in zip(train_steps, train_rewards)],\n",
618
+ " 'config': {'model': MODEL_NAME, 'n_per_task': N_PER_TASK, 'num_generations': cfg.num_generations,\n",
619
+ " 'epochs': cfg.num_train_epochs, 'lr': cfg.learning_rate}\n",
620
+ " }, f, indent=2, default=str)\n",
621
+ "print(f'✓ Saved {OUT_DIR}/training_results.png and {OUT_DIR}/results.json')"
622
+ ],
623
+ "id": "yZaNvORKxz1t"
624
+ },
625
+ {
626
+ "cell_type": "markdown",
627
+ "metadata": {
628
+ "id": "RAUL-h3zxz1t"
629
+ },
630
+ "source": [
631
+ "## 12. Download artifacts to your laptop"
632
+ ],
633
+ "id": "RAUL-h3zxz1t"
634
+ },
635
+ {
636
+ "cell_type": "code",
637
+ "metadata": {
638
+ "id": "CMbTmTZLxz1t"
639
+ },
640
+ "execution_count": null,
641
+ "outputs": [],
642
+ "source": [
643
+ "from google.colab import files\n",
644
+ "files.download(f'{OUT_DIR}/training_results.png')\n",
645
+ "files.download(f'{OUT_DIR}/results.json')"
646
+ ],
647
+ "id": "CMbTmTZLxz1t"
648
+ },
649
+ {
650
+ "cell_type": "markdown",
651
+ "metadata": {
652
+ "id": "w2wrkkAAxz1t"
653
+ },
654
+ "source": [
655
+ "---\n",
656
+ "\n",
657
+ "**After download:**\n",
658
+ "1. Drop `training_results.png` into your project's `training_logs/` folder\n",
659
+ "2. Embed it in your README under a 'Training Results' section\n",
660
+ "3. Commit & push to your HF Space\n",
661
+ "4. You're done. Switch to writing the mini-blog (Opus session)."
662
+ ],
663
+ "id": "w2wrkkAAxz1t"
664
+ }
665
+ ]
666
  }