DevanshuDon commited on
Commit
eb2095c
·
verified ·
1 Parent(s): 729fef6

Upload train_colab.ipynb

Browse files
Files changed (1) hide show
  1. train_colab.ipynb +530 -0
train_colab.ipynb ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {
6
+ "display_name": "Python 3",
7
+ "language": "python",
8
+ "name": "python3"
9
+ },
10
+ "language_info": {
11
+ "name": "python",
12
+ "version": "3.10"
13
+ },
14
+ "accelerator": "GPU",
15
+ "colab": {
16
+ "provenance": [],
17
+ "gpuType": "T4"
18
+ }
19
+ },
20
+ "cells": [
21
+ {
22
+ "cell_type": "markdown",
23
+ "metadata": {},
24
+ "source": [
25
+ "# ExecAssist \u2014 GRPO Training on Colab T4\n",
26
+ "\n",
27
+ "Trains Qwen2.5-0.5B-Instruct with TRL GRPO on the deployed ExecAssist environment.\n",
28
+ "\n",
29
+ "**Runtime:** Set runtime \u2192 T4 GPU. Total run ~45 min.\n",
30
+ "\n",
31
+ "**Outputs:** `training_results.png` + `results.json` for your README."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "metadata": {},
37
+ "source": [
38
+ "## 1. Install + GPU check"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "metadata": {},
44
+ "execution_count": null,
45
+ "outputs": [],
46
+ "source": [
47
+ "!pip install -q -U torch transformers trl datasets accelerate matplotlib"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "metadata": {},
53
+ "execution_count": null,
54
+ "outputs": [],
55
+ "source": [
56
+ "import torch\n",
57
+ "assert torch.cuda.is_available(), 'No GPU detected \u2014 set Runtime > Change runtime type > T4 GPU'\n",
58
+ "print('GPU:', torch.cuda.get_device_name(0))\n",
59
+ "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9, 1), 'GB')"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "## 2. Config"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "metadata": {},
72
+ "execution_count": null,
73
+ "outputs": [],
74
+ "source": [
75
+ "import os, json, requests, copy\n",
76
+ "from datetime import datetime\n",
77
+ "\n",
78
+ "ENV_URL = 'https://devanshudon-exec-assist.hf.space'\n",
79
+ "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
80
+ "N_PER_TASK = 30 # 30 each = 90 scenarios total\n",
81
+ "N_EVAL = 10 # baseline/trained eval samples per task\n",
82
+ "OUT_DIR = 'training_logs'\n",
83
+ "os.makedirs(OUT_DIR, exist_ok=True)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "metadata": {},
89
+ "source": [
90
+ "## 3. Collect scenarios from your HF Space\n",
91
+ "\n",
92
+ "Each `/reset` returns a fresh scenario. We pull 90 of them once and reuse them for both training and evaluation \u2014 this keeps reward scoring deterministic."
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "metadata": {},
98
+ "execution_count": null,
99
+ "outputs": [],
100
+ "source": [
101
+ "def reset_env(task):\n",
102
+ " r = requests.post(f'{ENV_URL}/reset', params={'task': task}, timeout=60)\n",
103
+ " r.raise_for_status()\n",
104
+ " return r.json()\n",
105
+ "\n",
106
+ "scenarios = []\n",
107
+ "for task in ['easy', 'medium', 'hard']:\n",
108
+ " for i in range(N_PER_TASK):\n",
109
+ " try:\n",
110
+ " data = reset_env(task)\n",
111
+ " scenarios.append({'task': task, 'observation': data.get('observation', {})})\n",
112
+ " except Exception as e:\n",
113
+ " print(f' ! {task}#{i}: {e}')\n",
114
+ " print(f' \u2713 {task}: {sum(1 for s in scenarios if s[\"task\"]==task)} collected')\n",
115
+ "\n",
116
+ "print(f'\\nTotal scenarios: {len(scenarios)}')"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "markdown",
121
+ "metadata": {},
122
+ "source": [
123
+ "## 4. Reward scoring (heuristic, deterministic, no API calls)\n",
124
+ "\n",
125
+ "Mirrors the env's weightings \u2014 Easy 50/50, Medium 30/40/30, Hard 34/33/33 \u2014 without the LLM judge so training stays fast."
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "metadata": {},
131
+ "execution_count": null,
132
+ "outputs": [],
133
+ "source": [
134
+ "def score_email(email):\n",
135
+ " if not email or len(email.strip()) < 20: return 0.0\n",
136
+ " e, score = email.lower(), 0.5\n",
137
+ " if any(w in e for w in ['thank','appreciate','please','regards','sincerely','best']): score += 0.15\n",
138
+ " if any(e.lstrip().startswith(g) for g in ['hi ','hello ','dear ','good morning','good afternoon']): score += 0.10\n",
139
+ " if any(c in e for c in ['regards','sincerely','best,','thanks,','cheers']): score += 0.10\n",
140
+ " wc = len(email.split())\n",
141
+ " if 25 <= wc <= 200: score += 0.15\n",
142
+ " elif wc > 250: score -= 0.10\n",
143
+ " elif wc < 15: score -= 0.20\n",
144
+ " return max(0.0, min(1.0, score))\n",
145
+ "\n",
146
+ "def score_scheduling(meeting, obs):\n",
147
+ " if not isinstance(meeting, dict): return 0.0\n",
148
+ " has_s = 'start_time' in meeting; has_e = 'end_time' in meeting\n",
149
+ " has_p = isinstance(meeting.get('participants'), list) and len(meeting.get('participants',[])) >= 2\n",
150
+ " has_subj = bool(meeting.get('subject'))\n",
151
+ " no_conflict = in_hours = good_dur = False\n",
152
+ " if has_s and has_e:\n",
153
+ " try:\n",
154
+ " ns = datetime.fromisoformat(meeting['start_time'])\n",
155
+ " ne = datetime.fromisoformat(meeting['end_time'])\n",
156
+ " in_hours = 9 <= ns.hour < 17\n",
157
+ " d = (ne - ns).total_seconds()/60\n",
158
+ " good_dur = 15 <= d <= 120\n",
159
+ " no_conflict = True\n",
160
+ " for m in obs.get('calendar', {}).get('existing_meetings', []):\n",
161
+ " try:\n",
162
+ " es = datetime.fromisoformat(m['start_time'])\n",
163
+ " ee = datetime.fromisoformat(m['end_time'])\n",
164
+ " if not (ne <= es or ns >= ee):\n",
165
+ " no_conflict = False; break\n",
166
+ " except: pass\n",
167
+ " except: pass\n",
168
+ " checks = [has_s, has_e, has_p, has_subj, no_conflict, in_hours, good_dur]\n",
169
+ " return sum(checks) / len(checks)\n",
170
+ "\n",
171
+ "def score_conflict(action, obs):\n",
172
+ " cal_action = action.get('calendar_action','')\n",
173
+ " meeting = action.get('meeting_details') or {}\n",
174
+ " email = (action.get('email_reply') or '').lower()\n",
175
+ " score = 0.5\n",
176
+ " if any(w in email for w in ['conflict','unavailable','alternative','instead','however','unfortunately']): score += 0.15\n",
177
+ " alts = meeting.get('proposed_alternatives') if isinstance(meeting, dict) else None\n",
178
+ " if alts and isinstance(alts, list) and len(alts) >= 2: score += 0.25\n",
179
+ " if cal_action == 'propose_alternatives': score += 0.10\n",
180
+ " return max(0.0, min(1.0, score))\n",
181
+ "\n",
182
+ "def penalties(action):\n",
183
+ " p = 0.0\n",
184
+ " email = action.get('email_reply','') or ''\n",
185
+ " if not email or len(email.strip()) < 20: p -= 0.30\n",
186
+ " if len(email) > 1500: p -= 0.15\n",
187
+ " if not action.get('meeting_details'): p -= 0.40\n",
188
+ " return p\n",
189
+ "\n",
190
+ "def compose_reward(action, obs, task):\n",
191
+ " if not isinstance(action, dict): return 0.0\n",
192
+ " e = score_email(action.get('email_reply',''))\n",
193
+ " s = score_scheduling(action.get('meeting_details'), obs)\n",
194
+ " c = score_conflict(action, obs)\n",
195
+ " p = penalties(action)\n",
196
+ " if task == 'easy': base = 0.50*e + 0.50*s\n",
197
+ " elif task == 'medium': base = 0.30*e + 0.40*c + 0.30*s\n",
198
+ " else: base = 0.34*e + 0.33*s + 0.33*c\n",
199
+ " return max(0.0, min(1.0, base + p))\n",
200
+ "\n",
201
+ "# Sanity check\n",
202
+ "demo = {'email_reply':'Hi Sarah, Thank you for reaching out. Monday at 2pm works well. Best regards, Alex Chen',\n",
203
+ " 'calendar_action':'book',\n",
204
+ " 'meeting_details':{'participants':['s@x.com','a@x.com'],\n",
205
+ " 'start_time':'2026-04-28T14:00:00',\n",
206
+ " 'end_time':'2026-04-28T14:30:00',\n",
207
+ " 'subject':'Test'}}\n",
208
+ "print('Demo reward (easy):', compose_reward(demo, scenarios[0]['observation'], 'easy'))"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {},
214
+ "source": [
215
+ "## 5. Build prompts + HF Dataset"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "metadata": {},
221
+ "execution_count": null,
222
+ "outputs": [],
223
+ "source": [
224
+ "def build_prompt(obs):\n",
225
+ " emails = obs.get('emails', [])\n",
226
+ " cal = obs.get('calendar', {})\n",
227
+ " em_str = ''\n",
228
+ " for e in emails:\n",
229
+ " em_str += f\"From: {e.get('sender','?')}\\nSubject: {e.get('subject','?')}\\nPriority: {e.get('priority','normal')}\\n{e.get('body','')}\\n---\\n\"\n",
230
+ " meetings = cal.get('existing_meetings', [])\n",
231
+ " cal_str = 'Existing meetings:\\n' + ('\\n'.join(f\" - {m.get('subject','?')}: {m.get('start_time','?')} -> {m.get('end_time','?')}\" for m in meetings) if meetings else ' (none)')\n",
232
+ " exec_name = cal.get('executive_name','Alex Chen')\n",
233
+ " return f\"\"\"You are {exec_name}'s executive assistant. Reply to incoming email and book the meeting.\n",
234
+ "\n",
235
+ "EMAILS:\n",
236
+ "{em_str}\n",
237
+ "\n",
238
+ "{cal_str}\n",
239
+ "\n",
240
+ "Working hours: 9-17. Duration: 15-120 min. Avoid double-booking.\n",
241
+ "\n",
242
+ "Respond with ONLY a JSON object, no commentary:\n",
243
+ "{{\\\"email_reply\\\":\\\"...\\\",\\\"calendar_action\\\":\\\"book\\\",\\\"meeting_details\\\":{{\\\"participants\\\":[\\\"a@x.com\\\",\\\"b@x.com\\\"],\\\"start_time\\\":\\\"2026-04-28T14:00:00\\\",\\\"end_time\\\":\\\"2026-04-28T14:30:00\\\",\\\"subject\\\":\\\"...\\\"}}}}\"\"\"\n",
244
+ "\n",
245
+ "from datasets import Dataset\n",
246
+ "ds = Dataset.from_dict({\n",
247
+ " 'prompt': [build_prompt(s['observation']) for s in scenarios],\n",
248
+ " 'task': [s['task'] for s in scenarios],\n",
249
+ " 'scenario': [s['observation'] for s in scenarios],\n",
250
+ "})\n",
251
+ "print(f'Dataset: {len(ds)} rows')\n",
252
+ "print('---\\nSample prompt (truncated):\\n', ds[0]['prompt'][:400], '...')"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "metadata": {},
258
+ "source": [
259
+ "## 6. JSON parser + GRPO reward function\n",
260
+ "\n",
261
+ "GRPOTrainer passes extra dataset columns (`task`, `scenario`) as kwargs to the reward function."
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "metadata": {},
267
+ "execution_count": null,
268
+ "outputs": [],
269
+ "source": [
270
+ "def parse_action(text):\n",
271
+ " try:\n",
272
+ " s = text.find('{')\n",
273
+ " if s == -1: return {}\n",
274
+ " depth = 0\n",
275
+ " for i, ch in enumerate(text[s:], start=s):\n",
276
+ " if ch == '{': depth += 1\n",
277
+ " elif ch == '}':\n",
278
+ " depth -= 1\n",
279
+ " if depth == 0:\n",
280
+ " return json.loads(text[s:i+1])\n",
281
+ " except Exception:\n",
282
+ " return {}\n",
283
+ " return {}\n",
284
+ "\n",
285
+ "def reward_function(completions, scenario, task, **kwargs):\n",
286
+ " rewards = []\n",
287
+ " for comp, scen, t in zip(completions, scenario, task):\n",
288
+ " text = comp[-1]['content'] if isinstance(comp, list) else comp\n",
289
+ " action = parse_action(text)\n",
290
+ " try:\n",
291
+ " r = compose_reward(action, scen, t)\n",
292
+ " except Exception:\n",
293
+ " r = 0.0\n",
294
+ " rewards.append(float(r))\n",
295
+ " return rewards\n",
296
+ "\n",
297
+ "# Smoke test\n",
298
+ "demo_comp = json.dumps(demo)\n",
299
+ "print('Reward fn output:', reward_function([demo_comp], [scenarios[0]['observation']], ['easy']))"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "markdown",
304
+ "metadata": {},
305
+ "source": [
306
+ "## 7. Load model + tokenizer"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "metadata": {},
312
+ "execution_count": null,
313
+ "outputs": [],
314
+ "source": [
315
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
316
+ "\n",
317
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
318
+ "if tokenizer.pad_token is None:\n",
319
+ " tokenizer.pad_token = tokenizer.eos_token\n",
320
+ "\n",
321
+ "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to('cuda')\n",
322
+ "print(f'Loaded {MODEL_NAME} \u2014 {sum(p.numel() for p in model.parameters())/1e6:.0f}M params')"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {},
328
+ "source": [
329
+ "## 8. Baseline evaluation (untrained Qwen-0.5B)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "metadata": {},
335
+ "execution_count": null,
336
+ "outputs": [],
337
+ "source": [
338
+ "def evaluate(m, tok, dataset, n_per_task=10):\n",
339
+ " m.eval()\n",
340
+ " results = {'easy':[], 'medium':[], 'hard':[]}\n",
341
+ " for ex in dataset:\n",
342
+ " t = ex['task']\n",
343
+ " if len(results[t]) >= n_per_task: continue\n",
344
+ " msgs = [{'role':'user','content': ex['prompt']}]\n",
345
+ " text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)\n",
346
+ " ids = tok(text, return_tensors='pt', truncation=True, max_length=1024).to('cuda')\n",
347
+ " with torch.no_grad():\n",
348
+ " out = m.generate(**ids, max_new_tokens=350, do_sample=True, temperature=0.7,\n",
349
+ " top_p=0.9, pad_token_id=tok.eos_token_id)\n",
350
+ " comp = tok.decode(out[0][ids['input_ids'].shape[1]:], skip_special_tokens=True)\n",
351
+ " action = parse_action(comp)\n",
352
+ " r = compose_reward(action, ex['scenario'], t) if action else 0.0\n",
353
+ " results[t].append(r)\n",
354
+ " return results\n",
355
+ "\n",
356
+ "print('Evaluating baseline...')\n",
357
+ "baseline = evaluate(model, tokenizer, ds, n_per_task=N_EVAL)\n",
358
+ "for t, rs in baseline.items():\n",
359
+ " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')"
360
+ ]
361
+ },
362
+ {
363
+ "cell_type": "markdown",
364
+ "metadata": {},
365
+ "source": [
366
+ "## 9. Train with TRL GRPO\n",
367
+ "\n",
368
+ "Hyperparameters chosen for T4 + 0.5B model. Should complete in ~25-35 min."
369
+ ]
370
+ },
371
+ {
372
+ "cell_type": "code",
373
+ "metadata": {},
374
+ "execution_count": null,
375
+ "outputs": [],
376
+ "source": [
377
+ "from trl import GRPOTrainer, GRPOConfig\n",
378
+ "\n",
379
+ "cfg = GRPOConfig(\n",
380
+ " output_dir='./grpo_out',\n",
381
+ " learning_rate=5e-6,\n",
382
+ " per_device_train_batch_size=2,\n",
383
+ " gradient_accumulation_steps=4,\n",
384
+ " num_generations=4,\n",
385
+ " max_prompt_length=768,\n",
386
+ " max_completion_length=300,\n",
387
+ " num_train_epochs=1,\n",
388
+ " logging_steps=1,\n",
389
+ " save_strategy='no',\n",
390
+ " fp16=True,\n",
391
+ " report_to='none',\n",
392
+ " temperature=0.9,\n",
393
+ " beta=0.04,\n",
394
+ " gradient_checkpointing=True,\n",
395
+ ")\n",
396
+ "\n",
397
+ "trainer = GRPOTrainer(\n",
398
+ " model=model,\n",
399
+ " args=cfg,\n",
400
+ " train_dataset=ds,\n",
401
+ " reward_funcs=reward_function,\n",
402
+ " processing_class=tokenizer,\n",
403
+ ")\n",
404
+ "\n",
405
+ "print('Starting GRPO training...')\n",
406
+ "trainer.train()\n",
407
+ "print('\u2705 Training complete')"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "markdown",
412
+ "metadata": {},
413
+ "source": [
414
+ "## 10. Post-training evaluation"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "metadata": {},
420
+ "execution_count": null,
421
+ "outputs": [],
422
+ "source": [
423
+ "print('Evaluating trained model...')\n",
424
+ "trained = evaluate(trainer.model, tokenizer, ds, n_per_task=N_EVAL)\n",
425
+ "for t, rs in trained.items():\n",
426
+ " print(f' {t:<7} avg={sum(rs)/max(len(rs),1):.3f} n={len(rs)}')\n",
427
+ "\n",
428
+ "# Compare\n",
429
+ "print('\\n Baseline \u2192 Trained')\n",
430
+ "for t in ['easy','medium','hard']:\n",
431
+ " b = sum(baseline[t])/max(len(baseline[t]),1)\n",
432
+ " tr = sum(trained[t])/max(len(trained[t]),1)\n",
433
+ " delta = (tr - b) / max(b, 1e-6) * 100\n",
434
+ " print(f' {t:<7} {b:.3f} \u2192 {tr:.3f} ({delta:+.1f}%)')"
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "markdown",
439
+ "metadata": {},
440
+ "source": [
441
+ "## 11. Plots + save"
442
+ ]
443
+ },
444
+ {
445
+ "cell_type": "code",
446
+ "metadata": {},
447
+ "execution_count": null,
448
+ "outputs": [],
449
+ "source": [
450
+ "import matplotlib.pyplot as plt\n",
451
+ "import numpy as np\n",
452
+ "\n",
453
+ "tasks = ['easy','medium','hard']\n",
454
+ "b_avg = [np.mean(baseline[t]) for t in tasks]\n",
455
+ "t_avg = [np.mean(trained[t]) for t in tasks]\n",
456
+ "\n",
457
+ "log = trainer.state.log_history\n",
458
+ "train_steps = [e.get('step') for e in log if 'reward' in e]\n",
459
+ "train_rewards = [e.get('reward') for e in log if 'reward' in e]\n",
460
+ "\n",
461
+ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
462
+ "\n",
463
+ "# Bar: baseline vs trained\n",
464
+ "x = np.arange(len(tasks)); w = 0.35\n",
465
+ "axes[0].bar(x - w/2, b_avg, w, label='Baseline (Qwen2.5-0.5B, untrained)', color='#e74c3c', alpha=0.85)\n",
466
+ "axes[0].bar(x + w/2, t_avg, w, label='Trained (GRPO)', color='#2ecc71', alpha=0.85)\n",
467
+ "axes[0].set_xticks(x); axes[0].set_xticklabels([t.capitalize() for t in tasks])\n",
468
+ "axes[0].set_ylabel('Mean reward'); axes[0].set_title('Baseline vs Trained \u2014 per task')\n",
469
+ "axes[0].legend(); axes[0].grid(axis='y', alpha=0.3); axes[0].set_ylim(0, 1)\n",
470
+ "for i, (b, tr) in enumerate(zip(b_avg, t_avg)):\n",
471
+ " axes[0].text(i - w/2, b + 0.02, f'{b:.3f}', ha='center', fontsize=9)\n",
472
+ " axes[0].text(i + w/2, tr + 0.02, f'{tr:.3f}', ha='center', fontsize=9)\n",
473
+ "\n",
474
+ "# Line: training reward over steps\n",
475
+ "if train_rewards:\n",
476
+ " axes[1].plot(train_steps, train_rewards, marker='o', color='#3498db', linewidth=2, markersize=4)\n",
477
+ " axes[1].set_xlabel('Training step'); axes[1].set_ylabel('Mean reward (batch)')\n",
478
+ " axes[1].set_title('GRPO \u2014 reward over training steps')\n",
479
+ " axes[1].grid(alpha=0.3)\n",
480
+ "else:\n",
481
+ " axes[1].text(0.5, 0.5, 'No training log captured', ha='center', va='center', transform=axes[1].transAxes)\n",
482
+ "\n",
483
+ "plt.tight_layout()\n",
484
+ "plt.savefig(f'{OUT_DIR}/training_results.png', dpi=150, bbox_inches='tight')\n",
485
+ "plt.show()\n",
486
+ "\n",
487
+ "with open(f'{OUT_DIR}/results.json', 'w') as f:\n",
488
+ " json.dump({\n",
489
+ " 'baseline': baseline,\n",
490
+ " 'trained': trained,\n",
491
+ " 'training_log': [{'step': s, 'reward': r} for s, r in zip(train_steps, train_rewards)],\n",
492
+ " 'config': {'model': MODEL_NAME, 'n_per_task': N_PER_TASK, 'num_generations': cfg.num_generations,\n",
493
+ " 'epochs': cfg.num_train_epochs, 'lr': cfg.learning_rate}\n",
494
+ " }, f, indent=2, default=str)\n",
495
+ "print(f'\u2713 Saved {OUT_DIR}/training_results.png and {OUT_DIR}/results.json')"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "markdown",
500
+ "metadata": {},
501
+ "source": [
502
+ "## 12. Download artifacts to your laptop"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "metadata": {},
508
+ "execution_count": null,
509
+ "outputs": [],
510
+ "source": [
511
+ "from google.colab import files\n",
512
+ "files.download(f'{OUT_DIR}/training_results.png')\n",
513
+ "files.download(f'{OUT_DIR}/results.json')"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "markdown",
518
+ "metadata": {},
519
+ "source": [
520
+ "---\n",
521
+ "\n",
522
+ "**After download:**\n",
523
+ "1. Drop `training_results.png` into your project's `training_logs/` folder\n",
524
+ "2. Embed it in your README under a 'Training Results' section\n",
525
+ "3. Commit & push to your HF Space\n",
526
+ "4. You're done \u2014 switch to writing the mini-blog (Opus session)."
527
+ ]
528
+ }
529
+ ]
530
+ }