Vighnesh committed on
Commit
a016315
·
1 Parent(s): 5d570d6

Remove redundant train_grpo_safe.ipynb

Browse files
Files changed (1) hide show
  1. train_grpo_safe.ipynb +0 -562
train_grpo_safe.ipynb DELETED
@@ -1,562 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "T4"
8
- },
9
- "kernelspec": {
10
- "display_name": "Python 3",
11
- "name": "python3"
12
- },
13
- "language_info": {
14
- "name": "python"
15
- },
16
- "accelerator": "GPU"
17
- },
18
- "cells": [
19
- {
20
- "cell_type": "markdown",
21
- "metadata": {},
22
- "source": [
23
- "# Support Ticket Env - GRPO Fine-Tuning\n",
24
- "**OpenEnv x Scalar Hackathon**\n",
25
- "\n",
26
- "Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using GRPO (Group Relative Policy Optimization) against the live Support Ticket Environment API. The GRPO update is a hand-written training loop in this notebook; HuggingFace TRL is installed but never imported.\n",
27
- "\n",
28
- "- Model: Qwen2.5-0.5B-Instruct\n",
29
- "- Algorithm: GRPO\n",
30
- "- Environment: https://algocore-support-ticket-env.hf.space\n",
31
- "- Runtime: ~45-60 min on free Colab T4"
32
- ]
33
- },
34
- {
35
- "cell_type": "code",
36
- "execution_count": null,
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "!pip install -q trl transformers peft accelerate\n",
41
- "!pip install -q torch bitsandbytes requests datasets\n",
42
- "print('Done')"
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "metadata": {},
49
- "outputs": [],
50
- "source": [
51
- "import os\n",
52
- "\n",
53
- "HF_TOKEN = \"YOUR_HF_TOKEN_HERE\"\n",
54
- "ENV_BASE_URL = \"https://algocore-support-ticket-env.hf.space\"\n",
55
- "MODEL_NAME = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
56
- "OUTPUT_DIR = \"/content/support-ticket-grpo\"\n",
57
- "HF_REPO_ID = \"AlgoCore/support-ticket-grpo-model\"\n",
58
- "\n",
59
- "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
60
- "os.environ[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n",
61
- "\n",
62
- "import torch\n",
63
- "print(\"GPU:\", torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"NO GPU - switch runtime!\")\n",
64
- "if torch.cuda.is_available():\n",
65
- " print(\"VRAM:\", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), \"GB\")\n",
66
- "print(\"Model:\", MODEL_NAME)\n",
67
- "print(\"Env:\", ENV_BASE_URL)"
68
- ]
69
- },
70
- {
71
- "cell_type": "code",
72
- "execution_count": null,
73
- "metadata": {},
74
- "outputs": [],
75
- "source": [
76
- "import requests\n",
77
- "import json\n",
78
- "import re\n",
79
- "from dataclasses import dataclass\n",
80
- "from typing import Optional\n",
81
- "\n",
82
- "@dataclass\n",
83
- "class Obs:\n",
84
- " ticket_id: str\n",
85
- " ticket_text: str\n",
86
- " task_id: int\n",
87
- " current_category: Optional[str]\n",
88
- " resolved: bool\n",
89
- " step_count: int\n",
90
- " feedback: str\n",
91
- " score: float\n",
92
- " reward: float\n",
93
- " done: bool\n",
94
- "\n",
95
- "class SupportEnvClient:\n",
96
- " def __init__(self, base_url):\n",
97
- " self.base_url = base_url.rstrip('/')\n",
98
- " self.session = requests.Session()\n",
99
- " self.session.headers.update({'Content-Type': 'application/json'})\n",
100
- "\n",
101
- " def health(self):\n",
102
- " try:\n",
103
- " r = self.session.get(f\"{self.base_url}/health\", timeout=10)\n",
104
- " return r.status_code == 200\n",
105
- " except:\n",
106
- " return False\n",
107
- "\n",
108
- " def reset(self, task_id=1, seed=42):\n",
109
- " r = self.session.post(f\"{self.base_url}/reset\", json={\"task_id\": task_id, \"seed\": seed}, timeout=15)\n",
110
- " r.raise_for_status()\n",
111
- " return self._parse(r.json())\n",
112
- "\n",
113
- " def step(self, action):\n",
114
- " r = self.session.post(f\"{self.base_url}/step\", json={\"action\": action}, timeout=15)\n",
115
- " r.raise_for_status()\n",
116
- " return self._parse(r.json())\n",
117
- "\n",
118
- " def _parse(self, data):\n",
119
- " obs = data.get('observation', data)\n",
120
- " return Obs(\n",
121
- " ticket_id=obs.get('ticket_id', ''),\n",
122
- " ticket_text=obs.get('ticket_text', ''),\n",
123
- " task_id=obs.get('task_id', 1),\n",
124
- " current_category=obs.get('current_category'),\n",
125
- " resolved=obs.get('resolved', False),\n",
126
- " step_count=obs.get('step_count', 0),\n",
127
- " feedback=obs.get('feedback', ''),\n",
128
- " score=obs.get('score', 0.0),\n",
129
- " reward=obs.get('reward', 0.0),\n",
130
- " done=obs.get('done', False),\n",
131
- " )\n",
132
- "\n",
133
- "env_client = SupportEnvClient(ENV_BASE_URL)\n",
134
- "if env_client.health():\n",
135
- " print('Environment API reachable')\n",
136
- " obs = env_client.reset(task_id=1, seed=42)\n",
137
- " print(f'Ticket: {obs.ticket_id} - {obs.ticket_text[:70]}')\n",
138
- "else:\n",
139
- " print('Cannot reach environment - check ENV_BASE_URL')"
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "metadata": {},
146
- "outputs": [],
147
- "source": [
148
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
149
- "import torch\n",
150
- "\n",
151
- "print(f\"Loading {MODEL_NAME}...\")\n",
152
- "\n",
153
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, trust_remote_code=True)\n",
154
- "tokenizer.pad_token = tokenizer.eos_token\n",
155
- "tokenizer.padding_side = 'left'\n",
156
- "\n",
157
- "model = AutoModelForCausalLM.from_pretrained(\n",
158
- " MODEL_NAME,\n",
159
- " token=HF_TOKEN,\n",
160
- " torch_dtype=torch.float16,\n",
161
- " device_map='auto',\n",
162
- " trust_remote_code=True,\n",
163
- ")\n",
164
- "\n",
165
- "print(f'Model loaded - {sum(p.numel() for p in model.parameters())/1e6:.0f}M parameters')\n",
166
- "print(f'Device: {next(model.parameters()).device}')"
167
- ]
168
- },
169
- {
170
- "cell_type": "code",
171
- "execution_count": null,
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "from peft import LoraConfig, get_peft_model, TaskType\n",
176
- "\n",
177
- "lora_config = LoraConfig(\n",
178
- " task_type=TaskType.CAUSAL_LM,\n",
179
- " r=16,\n",
180
- " lora_alpha=32,\n",
181
- " lora_dropout=0.05,\n",
182
- " target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],\n",
183
- " bias=\"none\",\n",
184
- ")\n",
185
- "\n",
186
- "model = get_peft_model(model, lora_config)\n",
187
- "model.print_trainable_parameters()"
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": null,
193
- "metadata": {},
194
- "outputs": [],
195
- "source": [
196
- "SYSTEM_PROMPT = \"\"\"You are a customer support AI agent. Given a ticket, respond with a JSON action.\n",
197
- "\n",
198
- "Respond ONLY with valid JSON:\n",
199
- "{\"action_type\": \"classify\"|\"reply\"|\"escalate\"|\"close\", \"category\": \"billing\"|\"technical\"|\"account\"|\"general\"|\"refund\", \"reply_text\": \"...\", \"reason\": \"...\"}\n",
200
- "\n",
201
- "Rules:\n",
202
- "- Task 1: action_type=classify, pick correct category\n",
203
- "- Task 2: first classify, then reply/escalate/close\n",
204
- "- Task 3: classify each ticket then resolve it\n",
205
- "- category only needed for classify\n",
206
- "- reply_text only needed for reply\n",
207
- "- technical issues: escalate\n",
208
- "- resolved issues: close\n",
209
- "- billing/account/refund: reply\"\"\"\n",
210
- "\n",
211
- "def build_prompt(obs):\n",
212
- " user_msg = json.dumps({\n",
213
- " \"ticket_id\": obs.ticket_id,\n",
214
- " \"ticket_text\": obs.ticket_text,\n",
215
- " \"task_id\": obs.task_id,\n",
216
- " \"current_category\": obs.current_category,\n",
217
- " \"feedback\": obs.feedback,\n",
218
- " \"step_count\": obs.step_count,\n",
219
- " }, indent=2)\n",
220
- " messages = [\n",
221
- " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
222
- " {\"role\": \"user\", \"content\": user_msg},\n",
223
- " ]\n",
224
- " return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
225
- "\n",
226
- "def parse_action(text):\n",
227
- " text = text.strip()\n",
228
- " text = re.sub(r'^```(?:json)?\\s*', '', text)\n",
229
- " text = re.sub(r'\\s*```$', '', text)\n",
230
- " try:\n",
231
- " return json.loads(text)\n",
232
- " except:\n",
233
- " match = re.search(r'\\{.*?\\}', text, re.DOTALL)\n",
234
- " if match:\n",
235
- " try:\n",
236
- " return json.loads(match.group())\n",
237
- " except:\n",
238
- " pass\n",
239
- " return {\"action_type\": \"classify\", \"category\": \"general\"}\n",
240
- "\n",
241
- "obs = env_client.reset(task_id=1, seed=42)\n",
242
- "prompt = build_prompt(obs)\n",
243
- "print('Prompt builder OK')\n",
244
- "print(f'Prompt length: {len(prompt)} chars')"
245
- ]
246
- },
247
- {
248
- "cell_type": "code",
249
- "execution_count": null,
250
- "metadata": {},
251
- "outputs": [],
252
- "source": [
253
- "import random\n",
254
- "\n",
255
- "SEEDS = [42, 7, 123, 0, 99]\n",
256
- "TASK_IDS = [1, 2, 3]\n",
257
- "MAX_STEPS = 6\n",
258
- "\n",
259
- "def generate_action(prompt, max_new_tokens=150):\n",
260
- " inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(model.device)\n",
261
- " with torch.no_grad():\n",
262
- " outputs = model.generate(\n",
263
- " **inputs,\n",
264
- " max_new_tokens=max_new_tokens,\n",
265
- " do_sample=True,\n",
266
- " temperature=0.7,\n",
267
- " top_p=0.9,\n",
268
- " pad_token_id=tokenizer.eos_token_id,\n",
269
- " )\n",
270
- " new_tokens = outputs[0][inputs['input_ids'].shape[1]:]\n",
271
- " return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
272
- "\n",
273
- "def run_episode(task_id, seed):\n",
274
- " obs = env_client.reset(task_id=task_id, seed=seed)\n",
275
- " prompts, completions, rewards = [], [], []\n",
276
- " for _ in range(MAX_STEPS):\n",
277
- " if obs.done:\n",
278
- " break\n",
279
- " prompt = build_prompt(obs)\n",
280
- " completion = generate_action(prompt)\n",
281
- " action = parse_action(completion)\n",
282
- " try:\n",
283
- " obs = env_client.step(action)\n",
284
- " reward = float(obs.reward or 0.0)\n",
285
- " except:\n",
286
- " reward = -0.1\n",
287
- " obs.done = True\n",
288
- " prompts.append(prompt)\n",
289
- " completions.append(completion)\n",
290
- " rewards.append(reward)\n",
291
- " if obs.done:\n",
292
- " break\n",
293
- " return prompts, completions, sum(rewards)\n",
294
- "\n",
295
- "print('Running smoke test...')\n",
296
- "p, c, r = run_episode(task_id=1, seed=42)\n",
297
- "print(f'Smoke test passed - steps={len(p)}, total_reward={r:.3f}')"
298
- ]
299
- },
300
- {
301
- "cell_type": "code",
302
- "execution_count": null,
303
- "metadata": {},
304
- "outputs": [],
305
- "source": [
306
- "def evaluate(n_seeds=3):\n",
307
- " results = {}\n",
308
- " seeds = SEEDS[:n_seeds]\n",
309
- " for task_id in [1, 2, 3]:\n",
310
- " task_rewards = []\n",
311
- " for seed in seeds:\n",
312
- " _, _, total = run_episode(task_id=task_id, seed=seed)\n",
313
- " normalized = round(max(0, min(1, total / MAX_STEPS)), 3)\n",
314
- " task_rewards.append(normalized)\n",
315
- " avg = round(sum(task_rewards) / len(task_rewards), 3)\n",
316
- " results[f'task{task_id}'] = avg\n",
317
- " print(f' Task {task_id}: {avg:.3f}')\n",
318
- " results['overall'] = round(sum(results.values()) / 3, 3)\n",
319
- " print(f' Overall: {results[\"overall\"]:.3f}')\n",
320
- " return results\n",
321
- "\n",
322
- "print('=== BASELINE (before training) ===')\n",
323
- "baseline_scores = evaluate(n_seeds=3)"
324
- ]
325
- },
326
- {
327
- "cell_type": "code",
328
- "execution_count": null,
329
- "metadata": {},
330
- "outputs": [],
331
- "source": [
332
- "from torch.optim import AdamW\n",
333
- "from transformers import get_linear_schedule_with_warmup\n",
334
- "import numpy as np\n",
335
- "\n",
336
- "LEARNING_RATE = 5e-5\n",
337
- "N_EPISODES = 60\n",
338
- "GROUP_SIZE = 4\n",
339
- "KL_COEFF = 0.01\n",
340
- "GRAD_CLIP = 1.0\n",
341
- "LOG_EVERY = 5\n",
342
- "\n",
343
- "optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)\n",
344
- "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5, num_training_steps=N_EPISODES)\n",
345
- "\n",
346
- "training_log = []\n",
347
- "\n",
348
- "print(f'Starting GRPO training: {N_EPISODES} episodes, group_size={GROUP_SIZE}')\n",
349
- "print('=' * 60)\n",
350
- "\n",
351
- "model.train()\n",
352
- "\n",
353
- "for episode in range(1, N_EPISODES + 1):\n",
354
- " task_id = random.choice(TASK_IDS)\n",
355
- " seed = random.choice(SEEDS)\n",
356
- "\n",
357
- " group_rewards = []\n",
358
- " group_prompts = []\n",
359
- " group_completions = []\n",
360
- "\n",
361
- " for g in range(GROUP_SIZE):\n",
362
- " obs = env_client.reset(task_id=task_id, seed=seed)\n",
363
- " prompt = build_prompt(obs)\n",
364
- " completion = generate_action(prompt)\n",
365
- " action = parse_action(completion)\n",
366
- " try:\n",
367
- " obs = env_client.step(action)\n",
368
- " reward = float(obs.reward or 0.0)\n",
369
- " except:\n",
370
- " reward = -0.1\n",
371
- " group_rewards.append(reward)\n",
372
- " group_prompts.append(prompt)\n",
373
- " group_completions.append(completion)\n",
374
- "\n",
375
- " rewards_arr = np.array(group_rewards, dtype=np.float32)\n",
376
- " advantages = (rewards_arr - rewards_arr.mean()) / (rewards_arr.std() + 1e-8)\n",
377
- "\n",
378
- " total_loss = torch.tensor(0.0, requires_grad=True, device=model.device)\n",
379
- " optimizer.zero_grad()\n",
380
- "\n",
381
- " for prompt, completion, adv in zip(group_prompts, group_completions, advantages):\n",
382
- " if not completion.strip():\n",
383
- " continue\n",
384
- " full_text = prompt + completion\n",
385
- " inputs = tokenizer(full_text, return_tensors='pt', truncation=True, max_length=1200).to(model.device)\n",
386
- " prompt_len = tokenizer(prompt, return_tensors='pt')[\"input_ids\"].shape[1]\n",
387
- " outputs = model(**inputs, labels=inputs['input_ids'])\n",
388
- " logits = outputs.logits[:, prompt_len-1:-1, :]\n",
389
- " target_ids = inputs['input_ids'][:, prompt_len:]\n",
390
- " if target_ids.shape[1] == 0:\n",
391
- " continue\n",
392
- " log_probs = torch.nn.functional.log_softmax(logits, dim=-1)\n",
393
- " token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)\n",
394
- " seq_log_prob = token_log_probs.mean()\n",
395
- " pg_loss = -torch.tensor(float(adv), device=model.device) * seq_log_prob\n",
396
- " kl_loss = KL_COEFF * (seq_log_prob ** 2)\n",
397
- " total_loss = total_loss + (pg_loss + kl_loss) / GROUP_SIZE\n",
398
- "\n",
399
- " total_loss.backward()\n",
400
- " torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
401
- " optimizer.step()\n",
402
- " scheduler.step()\n",
403
- "\n",
404
- " avg_reward = float(rewards_arr.mean())\n",
405
- " training_log.append((episode, task_id, avg_reward))\n",
406
- "\n",
407
- " if episode % LOG_EVERY == 0:\n",
408
- " print(f'Episode {episode:3d}/{N_EPISODES} | task={task_id} | avg_reward={avg_reward:.3f} | loss={total_loss.item():.4f}')\n",
409
- "\n",
410
- "print('Training complete!')"
411
- ]
412
- },
413
- {
414
- "cell_type": "code",
415
- "execution_count": null,
416
- "metadata": {},
417
- "outputs": [],
418
- "source": [
419
- "model.eval()\n",
420
- "\n",
421
- "print('=== POST-TRAINING EVALUATION ===')\n",
422
- "trained_scores = evaluate(n_seeds=3)\n",
423
- "\n",
424
- "print('\\n=== IMPROVEMENT SUMMARY ===')\n",
425
- "print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
426
- "print('-' * 38)\n",
427
- "for key, label in [(\"task1\",\"Task 1\"),(\"task2\",\"Task 2\"),(\"task3\",\"Task 3\"),(\"overall\",\"Overall\")]:\n",
428
- " b = baseline_scores.get(key, 0)\n",
429
- " a = trained_scores.get(key, 0)\n",
430
- " d = a - b\n",
431
- " print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')"
432
- ]
433
- },
434
- {
435
- "cell_type": "code",
436
- "execution_count": null,
437
- "metadata": {},
438
- "outputs": [],
439
- "source": [
440
- "import matplotlib.pyplot as plt\n",
441
- "import numpy as np\n",
442
- "\n",
443
- "episodes = [x[0] for x in training_log]\n",
444
- "task_ids = [x[1] for x in training_log]\n",
445
- "ep_rewards = [x[2] for x in training_log]\n",
446
- "\n",
447
- "def moving_avg(data, window=5):\n",
448
- " return np.convolve(data, np.ones(window)/window, mode='valid')\n",
449
- "\n",
450
- "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
451
- "fig.suptitle('Support Ticket Env - GRPO Training Results', fontsize=14, fontweight='bold')\n",
452
- "\n",
453
- "ax1 = axes[0]\n",
454
- "colors = {1: '#3498db', 2: '#2ecc71', 3: '#e74c3c'}\n",
455
- "for tid in [1, 2, 3]:\n",
456
- " mask = [i for i, t in enumerate(task_ids) if t == tid]\n",
457
- " if mask:\n",
458
- " x = [episodes[i] for i in mask]\n",
459
- " y = [ep_rewards[i] for i in mask]\n",
460
- " ax1.scatter(x, y, alpha=0.3, color=colors[tid], s=15)\n",
461
- " if len(y) >= 5:\n",
462
- " smoothed = moving_avg(y)\n",
463
- " ax1.plot(x[2:-2], smoothed, color=colors[tid], linewidth=2, label=f'Task {tid}')\n",
464
- " else:\n",
465
- " ax1.plot(x, y, color=colors[tid], linewidth=2, label=f'Task {tid}')\n",
466
- "\n",
467
- "ax1.set_xlabel('Episode')\n",
468
- "ax1.set_ylabel('Avg Reward')\n",
469
- "ax1.set_title('Training Reward per Episode')\n",
470
- "ax1.legend()\n",
471
- "ax1.grid(True, alpha=0.3)\n",
472
- "ax1.set_ylim(-0.1, 1.1)\n",
473
- "\n",
474
- "ax2 = axes[1]\n",
475
- "tasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\n",
476
- "keys = ['task1', 'task2', 'task3', 'overall']\n",
477
- "before_vals = [baseline_scores.get(k, 0) for k in keys]\n",
478
- "after_vals = [trained_scores.get(k, 0) for k in keys]\n",
479
- "\n",
480
- "x = np.arange(len(tasks))\n",
481
- "width = 0.35\n",
482
- "\n",
483
- "bars1 = ax2.bar(x - width/2, before_vals, width, label='Before Training', color='#95a5a6')\n",
484
- "bars2 = ax2.bar(x + width/2, after_vals, width, label='After GRPO', color='#2ecc71')\n",
485
- "\n",
486
- "for bar in bars1:\n",
487
- " ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,\n",
488
- " f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
489
- "for bar in bars2:\n",
490
- " ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,\n",
491
- " f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9,\n",
492
- " fontweight='bold', color='#27ae60')\n",
493
- "\n",
494
- "ax2.set_xticks(x)\n",
495
- "ax2.set_xticklabels(tasks)\n",
496
- "ax2.set_ylabel('Score (0-1)')\n",
497
- "ax2.set_title('Before vs After GRPO Training')\n",
498
- "ax2.legend()\n",
499
- "ax2.grid(True, alpha=0.3, axis='y')\n",
500
- "ax2.set_ylim(0, 1.15)\n",
501
- "\n",
502
- "plt.tight_layout()\n",
503
- "plt.savefig('/content/grpo_results.png', dpi=150, bbox_inches='tight')\n",
504
- "plt.show()\n",
505
- "print('Chart saved to /content/grpo_results.png')"
506
- ]
507
- },
508
- {
509
- "cell_type": "code",
510
- "execution_count": null,
511
- "metadata": {},
512
- "outputs": [],
513
- "source": [
514
- "import os\n",
515
- "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
516
- "\n",
517
- "model.save_pretrained(OUTPUT_DIR)\n",
518
- "tokenizer.save_pretrained(OUTPUT_DIR)\n",
519
- "print(f'Model saved to {OUTPUT_DIR}')\n",
520
- "\n",
521
- "try:\n",
522
- " from huggingface_hub import HfApi\n",
523
- " api = HfApi(token=HF_TOKEN)\n",
524
- " api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
525
- " api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n",
526
- " api.upload_file(path_or_fileobj='/content/grpo_results.png', path_in_repo='grpo_results.png', repo_id=HF_REPO_ID, repo_type='model')\n",
527
- " print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
528
- "except Exception as e:\n",
529
- " print(f'Push failed: {e}')\n",
530
- " print(f'Model is saved locally at {OUTPUT_DIR}')"
531
- ]
532
- },
533
- {
534
- "cell_type": "code",
535
- "execution_count": null,
536
- "metadata": {},
537
- "outputs": [],
538
- "source": [
539
- "from google.colab import files\n",
540
- "files.download('/content/grpo_results.png')\n",
541
- "\n",
542
- "print('\\n' + '='*50)\n",
543
- "print('FINAL TRAINING SUMMARY')\n",
544
- "print('='*50)\n",
545
- "print(f'Model: {MODEL_NAME}')\n",
546
- "print(f'Algorithm: GRPO')\n",
547
- "print(f'Episodes: {N_EPISODES}')\n",
548
- "print(f'Env: {ENV_BASE_URL}')\n",
549
- "print()\n",
550
- "print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
551
- "print('-' * 38)\n",
552
- "for key, label in [(\"task1\",\"Task 1\"),(\"task2\",\"Task 2\"),(\"task3\",\"Task 3\"),(\"overall\",\"Overall\")]:\n",
553
- " b = baseline_scores.get(key, 0)\n",
554
- " a = trained_scores.get(key, 0)\n",
555
- " d = a - b\n",
556
- " print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')\n",
557
- "print('='*50)\n",
558
- "print(f'Model: https://huggingface.co/{HF_REPO_ID}')"
559
- ]
560
- }
561
- ]
562
- }