Spaces:
Sleeping
Sleeping
Vighnesh committed on
Commit ·
a016315
1
Parent(s): 5d570d6
Remove redundant train_grpo_safe.ipynb
Browse files- train_grpo_safe.ipynb +0 -562
train_grpo_safe.ipynb
DELETED
|
@@ -1,562 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"nbformat": 4,
|
| 3 |
-
"nbformat_minor": 0,
|
| 4 |
-
"metadata": {
|
| 5 |
-
"colab": {
|
| 6 |
-
"provenance": [],
|
| 7 |
-
"gpuType": "T4"
|
| 8 |
-
},
|
| 9 |
-
"kernelspec": {
|
| 10 |
-
"display_name": "Python 3",
|
| 11 |
-
"name": "python3"
|
| 12 |
-
},
|
| 13 |
-
"language_info": {
|
| 14 |
-
"name": "python"
|
| 15 |
-
},
|
| 16 |
-
"accelerator": "GPU"
|
| 17 |
-
},
|
| 18 |
-
"cells": [
|
| 19 |
-
{
|
| 20 |
-
"cell_type": "markdown",
|
| 21 |
-
"metadata": {},
|
| 22 |
-
"source": [
|
| 23 |
-
"# Support Ticket Env - GRPO Fine-Tuning\n",
|
| 24 |
-
"**OpenEnv x Scalar Hackathon**\n",
|
| 25 |
-
"\n",
|
| 26 |
-
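"Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` with a hand-written GRPO-style (Group Relative Policy Optimization) training loop in PyTorch against the live Support Ticket Environment API. The HuggingFace TRL package is installed below, but its trainer is not used.\n",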
|
| 27 |
-
"\n",
|
| 28 |
-
"- Model: Qwen2.5-0.5B-Instruct\n",
|
| 29 |
-
"- Algorithm: GRPO\n",
|
| 30 |
-
"- Environment: https://algocore-support-ticket-env.hf.space\n",
|
| 31 |
-
"- Runtime: ~45-60 min on free Colab T4"
|
| 32 |
-
]
|
| 33 |
-
},
|
| 34 |
-
{
|
| 35 |
-
"cell_type": "code",
|
| 36 |
-
"execution_count": null,
|
| 37 |
-
"metadata": {},
|
| 38 |
-
"outputs": [],
|
| 39 |
-
"source": [
|
| 40 |
-
"!pip install -q trl transformers peft accelerate\n",
|
| 41 |
-
"!pip install -q torch bitsandbytes requests datasets\n",
|
| 42 |
-
"print('Done')"
|
| 43 |
-
]
|
| 44 |
-
},
|
| 45 |
-
{
|
| 46 |
-
"cell_type": "code",
|
| 47 |
-
"execution_count": null,
|
| 48 |
-
"metadata": {},
|
| 49 |
-
"outputs": [],
|
| 50 |
-
"source": [
|
| 51 |
-
"import os\n",
|
| 52 |
-
"\n",
|
| 53 |
-
"HF_TOKEN = \"YOUR_HF_TOKEN_HERE\"\n",
|
| 54 |
-
"ENV_BASE_URL = \"https://algocore-support-ticket-env.hf.space\"\n",
|
| 55 |
-
"MODEL_NAME = \"Qwen/Qwen2.5-0.5B-Instruct\"\n",
|
| 56 |
-
"OUTPUT_DIR = \"/content/support-ticket-grpo\"\n",
|
| 57 |
-
"HF_REPO_ID = \"AlgoCore/support-ticket-grpo-model\"\n",
|
| 58 |
-
"\n",
|
| 59 |
-
"os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
|
| 60 |
-
"os.environ[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n",
|
| 61 |
-
"\n",
|
| 62 |
-
"import torch\n",
|
| 63 |
-
"print(\"GPU:\", torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"NO GPU - switch runtime!\")\n",
|
| 64 |
-
"if torch.cuda.is_available():\n",
|
| 65 |
-
" print(\"VRAM:\", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), \"GB\")\n",
|
| 66 |
-
"print(\"Model:\", MODEL_NAME)\n",
|
| 67 |
-
"print(\"Env:\", ENV_BASE_URL)"
|
| 68 |
-
]
|
| 69 |
-
},
|
| 70 |
-
{
|
| 71 |
-
"cell_type": "code",
|
| 72 |
-
"execution_count": null,
|
| 73 |
-
"metadata": {},
|
| 74 |
-
"outputs": [],
|
| 75 |
-
"source": [
|
| 76 |
-
"import requests\n",
|
| 77 |
-
"import json\n",
|
| 78 |
-
"import re\n",
|
| 79 |
-
"from dataclasses import dataclass\n",
|
| 80 |
-
"from typing import Optional\n",
|
| 81 |
-
"\n",
|
| 82 |
-
"@dataclass\n",
|
| 83 |
-
"class Obs:\n",
|
| 84 |
-
" ticket_id: str\n",
|
| 85 |
-
" ticket_text: str\n",
|
| 86 |
-
" task_id: int\n",
|
| 87 |
-
" current_category: Optional[str]\n",
|
| 88 |
-
" resolved: bool\n",
|
| 89 |
-
" step_count: int\n",
|
| 90 |
-
" feedback: str\n",
|
| 91 |
-
" score: float\n",
|
| 92 |
-
" reward: float\n",
|
| 93 |
-
" done: bool\n",
|
| 94 |
-
"\n",
|
| 95 |
-
"class SupportEnvClient:\n",
|
| 96 |
-
" def __init__(self, base_url):\n",
|
| 97 |
-
" self.base_url = base_url.rstrip('/')\n",
|
| 98 |
-
" self.session = requests.Session()\n",
|
| 99 |
-
" self.session.headers.update({'Content-Type': 'application/json'})\n",
|
| 100 |
-
"\n",
|
| 101 |
-
" def health(self):\n",
|
| 102 |
-
" try:\n",
|
| 103 |
-
" r = self.session.get(f\"{self.base_url}/health\", timeout=10)\n",
|
| 104 |
-
" return r.status_code == 200\n",
|
| 105 |
-
" except:\n",
|
| 106 |
-
" return False\n",
|
| 107 |
-
"\n",
|
| 108 |
-
" def reset(self, task_id=1, seed=42):\n",
|
| 109 |
-
" r = self.session.post(f\"{self.base_url}/reset\", json={\"task_id\": task_id, \"seed\": seed}, timeout=15)\n",
|
| 110 |
-
" r.raise_for_status()\n",
|
| 111 |
-
" return self._parse(r.json())\n",
|
| 112 |
-
"\n",
|
| 113 |
-
" def step(self, action):\n",
|
| 114 |
-
" r = self.session.post(f\"{self.base_url}/step\", json={\"action\": action}, timeout=15)\n",
|
| 115 |
-
" r.raise_for_status()\n",
|
| 116 |
-
" return self._parse(r.json())\n",
|
| 117 |
-
"\n",
|
| 118 |
-
" def _parse(self, data):\n",
|
| 119 |
-
" obs = data.get('observation', data)\n",
|
| 120 |
-
" return Obs(\n",
|
| 121 |
-
" ticket_id=obs.get('ticket_id', ''),\n",
|
| 122 |
-
" ticket_text=obs.get('ticket_text', ''),\n",
|
| 123 |
-
" task_id=obs.get('task_id', 1),\n",
|
| 124 |
-
" current_category=obs.get('current_category'),\n",
|
| 125 |
-
" resolved=obs.get('resolved', False),\n",
|
| 126 |
-
" step_count=obs.get('step_count', 0),\n",
|
| 127 |
-
" feedback=obs.get('feedback', ''),\n",
|
| 128 |
-
" score=obs.get('score', 0.0),\n",
|
| 129 |
-
" reward=obs.get('reward', 0.0),\n",
|
| 130 |
-
" done=obs.get('done', False),\n",
|
| 131 |
-
" )\n",
|
| 132 |
-
"\n",
|
| 133 |
-
"env_client = SupportEnvClient(ENV_BASE_URL)\n",
|
| 134 |
-
"if env_client.health():\n",
|
| 135 |
-
" print('Environment API reachable')\n",
|
| 136 |
-
" obs = env_client.reset(task_id=1, seed=42)\n",
|
| 137 |
-
" print(f'Ticket: {obs.ticket_id} - {obs.ticket_text[:70]}')\n",
|
| 138 |
-
"else:\n",
|
| 139 |
-
" print('Cannot reach environment - check ENV_BASE_URL')"
|
| 140 |
-
]
|
| 141 |
-
},
|
| 142 |
-
{
|
| 143 |
-
"cell_type": "code",
|
| 144 |
-
"execution_count": null,
|
| 145 |
-
"metadata": {},
|
| 146 |
-
"outputs": [],
|
| 147 |
-
"source": [
|
| 148 |
-
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
| 149 |
-
"import torch\n",
|
| 150 |
-
"\n",
|
| 151 |
-
"print(f\"Loading {MODEL_NAME}...\")\n",
|
| 152 |
-
"\n",
|
| 153 |
-
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN, trust_remote_code=True)\n",
|
| 154 |
-
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 155 |
-
"tokenizer.padding_side = 'left'\n",
|
| 156 |
-
"\n",
|
| 157 |
-
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 158 |
-
" MODEL_NAME,\n",
|
| 159 |
-
" token=HF_TOKEN,\n",
|
| 160 |
-
" torch_dtype=torch.float16,\n",
|
| 161 |
-
" device_map='auto',\n",
|
| 162 |
-
" trust_remote_code=True,\n",
|
| 163 |
-
")\n",
|
| 164 |
-
"\n",
|
| 165 |
-
"print(f'Model loaded - {sum(p.numel() for p in model.parameters())/1e6:.0f}M parameters')\n",
|
| 166 |
-
"print(f'Device: {next(model.parameters()).device}')"
|
| 167 |
-
]
|
| 168 |
-
},
|
| 169 |
-
{
|
| 170 |
-
"cell_type": "code",
|
| 171 |
-
"execution_count": null,
|
| 172 |
-
"metadata": {},
|
| 173 |
-
"outputs": [],
|
| 174 |
-
"source": [
|
| 175 |
-
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 176 |
-
"\n",
|
| 177 |
-
"lora_config = LoraConfig(\n",
|
| 178 |
-
" task_type=TaskType.CAUSAL_LM,\n",
|
| 179 |
-
" r=16,\n",
|
| 180 |
-
" lora_alpha=32,\n",
|
| 181 |
-
" lora_dropout=0.05,\n",
|
| 182 |
-
" target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\"],\n",
|
| 183 |
-
" bias=\"none\",\n",
|
| 184 |
-
")\n",
|
| 185 |
-
"\n",
|
| 186 |
-
"model = get_peft_model(model, lora_config)\n",
|
| 187 |
-
"model.print_trainable_parameters()"
|
| 188 |
-
]
|
| 189 |
-
},
|
| 190 |
-
{
|
| 191 |
-
"cell_type": "code",
|
| 192 |
-
"execution_count": null,
|
| 193 |
-
"metadata": {},
|
| 194 |
-
"outputs": [],
|
| 195 |
-
"source": [
|
| 196 |
-
"SYSTEM_PROMPT = \"\"\"You are a customer support AI agent. Given a ticket, respond with a JSON action.\n",
|
| 197 |
-
"\n",
|
| 198 |
-
"Respond ONLY with valid JSON:\n",
|
| 199 |
-
"{\"action_type\": \"classify\"|\"reply\"|\"escalate\"|\"close\", \"category\": \"billing\"|\"technical\"|\"account\"|\"general\"|\"refund\", \"reply_text\": \"...\", \"reason\": \"...\"}\n",
|
| 200 |
-
"\n",
|
| 201 |
-
"Rules:\n",
|
| 202 |
-
"- Task 1: action_type=classify, pick correct category\n",
|
| 203 |
-
"- Task 2: first classify, then reply/escalate/close\n",
|
| 204 |
-
"- Task 3: classify each ticket then resolve it\n",
|
| 205 |
-
"- category only needed for classify\n",
|
| 206 |
-
"- reply_text only needed for reply\n",
|
| 207 |
-
"- technical issues: escalate\n",
|
| 208 |
-
"- resolved issues: close\n",
|
| 209 |
-
"- billing/account/refund: reply\"\"\"\n",
|
| 210 |
-
"\n",
|
| 211 |
-
"def build_prompt(obs):\n",
|
| 212 |
-
" user_msg = json.dumps({\n",
|
| 213 |
-
" \"ticket_id\": obs.ticket_id,\n",
|
| 214 |
-
" \"ticket_text\": obs.ticket_text,\n",
|
| 215 |
-
" \"task_id\": obs.task_id,\n",
|
| 216 |
-
" \"current_category\": obs.current_category,\n",
|
| 217 |
-
" \"feedback\": obs.feedback,\n",
|
| 218 |
-
" \"step_count\": obs.step_count,\n",
|
| 219 |
-
" }, indent=2)\n",
|
| 220 |
-
" messages = [\n",
|
| 221 |
-
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 222 |
-
" {\"role\": \"user\", \"content\": user_msg},\n",
|
| 223 |
-
" ]\n",
|
| 224 |
-
" return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 225 |
-
"\n",
|
| 226 |
-
"def parse_action(text):\n",
|
| 227 |
-
" text = text.strip()\n",
|
| 228 |
-
" text = re.sub(r'^```(?:json)?\\s*', '', text)\n",
|
| 229 |
-
" text = re.sub(r'\\s*```$', '', text)\n",
|
| 230 |
-
" try:\n",
|
| 231 |
-
" return json.loads(text)\n",
|
| 232 |
-
" except:\n",
|
| 233 |
-
" match = re.search(r'\\{.*?\\}', text, re.DOTALL)\n",
|
| 234 |
-
" if match:\n",
|
| 235 |
-
" try:\n",
|
| 236 |
-
" return json.loads(match.group())\n",
|
| 237 |
-
" except:\n",
|
| 238 |
-
" pass\n",
|
| 239 |
-
" return {\"action_type\": \"classify\", \"category\": \"general\"}\n",
|
| 240 |
-
"\n",
|
| 241 |
-
"obs = env_client.reset(task_id=1, seed=42)\n",
|
| 242 |
-
"prompt = build_prompt(obs)\n",
|
| 243 |
-
"print('Prompt builder OK')\n",
|
| 244 |
-
"print(f'Prompt length: {len(prompt)} chars')"
|
| 245 |
-
]
|
| 246 |
-
},
|
| 247 |
-
{
|
| 248 |
-
"cell_type": "code",
|
| 249 |
-
"execution_count": null,
|
| 250 |
-
"metadata": {},
|
| 251 |
-
"outputs": [],
|
| 252 |
-
"source": [
|
| 253 |
-
"import random\n",
|
| 254 |
-
"\n",
|
| 255 |
-
"SEEDS = [42, 7, 123, 0, 99]\n",
|
| 256 |
-
"TASK_IDS = [1, 2, 3]\n",
|
| 257 |
-
"MAX_STEPS = 6\n",
|
| 258 |
-
"\n",
|
| 259 |
-
"def generate_action(prompt, max_new_tokens=150):\n",
|
| 260 |
-
" inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(model.device)\n",
|
| 261 |
-
" with torch.no_grad():\n",
|
| 262 |
-
" outputs = model.generate(\n",
|
| 263 |
-
" **inputs,\n",
|
| 264 |
-
" max_new_tokens=max_new_tokens,\n",
|
| 265 |
-
" do_sample=True,\n",
|
| 266 |
-
" temperature=0.7,\n",
|
| 267 |
-
" top_p=0.9,\n",
|
| 268 |
-
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 269 |
-
" )\n",
|
| 270 |
-
" new_tokens = outputs[0][inputs['input_ids'].shape[1]:]\n",
|
| 271 |
-
" return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
|
| 272 |
-
"\n",
|
| 273 |
-
"def run_episode(task_id, seed):\n",
|
| 274 |
-
" obs = env_client.reset(task_id=task_id, seed=seed)\n",
|
| 275 |
-
" prompts, completions, rewards = [], [], []\n",
|
| 276 |
-
" for _ in range(MAX_STEPS):\n",
|
| 277 |
-
" if obs.done:\n",
|
| 278 |
-
" break\n",
|
| 279 |
-
" prompt = build_prompt(obs)\n",
|
| 280 |
-
" completion = generate_action(prompt)\n",
|
| 281 |
-
" action = parse_action(completion)\n",
|
| 282 |
-
" try:\n",
|
| 283 |
-
" obs = env_client.step(action)\n",
|
| 284 |
-
" reward = float(obs.reward or 0.0)\n",
|
| 285 |
-
" except:\n",
|
| 286 |
-
" reward = -0.1\n",
|
| 287 |
-
" obs.done = True\n",
|
| 288 |
-
" prompts.append(prompt)\n",
|
| 289 |
-
" completions.append(completion)\n",
|
| 290 |
-
" rewards.append(reward)\n",
|
| 291 |
-
" if obs.done:\n",
|
| 292 |
-
" break\n",
|
| 293 |
-
" return prompts, completions, sum(rewards)\n",
|
| 294 |
-
"\n",
|
| 295 |
-
"print('Running smoke test...')\n",
|
| 296 |
-
"p, c, r = run_episode(task_id=1, seed=42)\n",
|
| 297 |
-
"print(f'Smoke test passed - steps={len(p)}, total_reward={r:.3f}')"
|
| 298 |
-
]
|
| 299 |
-
},
|
| 300 |
-
{
|
| 301 |
-
"cell_type": "code",
|
| 302 |
-
"execution_count": null,
|
| 303 |
-
"metadata": {},
|
| 304 |
-
"outputs": [],
|
| 305 |
-
"source": [
|
| 306 |
-
"def evaluate(n_seeds=3):\n",
|
| 307 |
-
" results = {}\n",
|
| 308 |
-
" seeds = SEEDS[:n_seeds]\n",
|
| 309 |
-
" for task_id in [1, 2, 3]:\n",
|
| 310 |
-
" task_rewards = []\n",
|
| 311 |
-
" for seed in seeds:\n",
|
| 312 |
-
" _, _, total = run_episode(task_id=task_id, seed=seed)\n",
|
| 313 |
-
" normalized = round(max(0, min(1, total / MAX_STEPS)), 3)\n",
|
| 314 |
-
" task_rewards.append(normalized)\n",
|
| 315 |
-
" avg = round(sum(task_rewards) / len(task_rewards), 3)\n",
|
| 316 |
-
" results[f'task{task_id}'] = avg\n",
|
| 317 |
-
" print(f' Task {task_id}: {avg:.3f}')\n",
|
| 318 |
-
" results['overall'] = round(sum(results.values()) / 3, 3)\n",
|
| 319 |
-
" print(f' Overall: {results[\"overall\"]:.3f}')\n",
|
| 320 |
-
" return results\n",
|
| 321 |
-
"\n",
|
| 322 |
-
"print('=== BASELINE (before training) ===')\n",
|
| 323 |
-
"baseline_scores = evaluate(n_seeds=3)"
|
| 324 |
-
]
|
| 325 |
-
},
|
| 326 |
-
{
|
| 327 |
-
"cell_type": "code",
|
| 328 |
-
"execution_count": null,
|
| 329 |
-
"metadata": {},
|
| 330 |
-
"outputs": [],
|
| 331 |
-
"source": [
|
| 332 |
-
"from torch.optim import AdamW\n",
|
| 333 |
-
"from transformers import get_linear_schedule_with_warmup\n",
|
| 334 |
-
"import numpy as np\n",
|
| 335 |
-
"\n",
|
| 336 |
-
"LEARNING_RATE = 5e-5\n",
|
| 337 |
-
"N_EPISODES = 60\n",
|
| 338 |
-
"GROUP_SIZE = 4\n",
|
| 339 |
-
"KL_COEFF = 0.01\n",
|
| 340 |
-
"GRAD_CLIP = 1.0\n",
|
| 341 |
-
"LOG_EVERY = 5\n",
|
| 342 |
-
"\n",
|
| 343 |
-
"optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)\n",
|
| 344 |
-
"scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5, num_training_steps=N_EPISODES)\n",
|
| 345 |
-
"\n",
|
| 346 |
-
"training_log = []\n",
|
| 347 |
-
"\n",
|
| 348 |
-
"print(f'Starting GRPO training: {N_EPISODES} episodes, group_size={GROUP_SIZE}')\n",
|
| 349 |
-
"print('=' * 60)\n",
|
| 350 |
-
"\n",
|
| 351 |
-
"model.train()\n",
|
| 352 |
-
"\n",
|
| 353 |
-
"for episode in range(1, N_EPISODES + 1):\n",
|
| 354 |
-
" task_id = random.choice(TASK_IDS)\n",
|
| 355 |
-
" seed = random.choice(SEEDS)\n",
|
| 356 |
-
"\n",
|
| 357 |
-
" group_rewards = []\n",
|
| 358 |
-
" group_prompts = []\n",
|
| 359 |
-
" group_completions = []\n",
|
| 360 |
-
"\n",
|
| 361 |
-
" for g in range(GROUP_SIZE):\n",
|
| 362 |
-
" obs = env_client.reset(task_id=task_id, seed=seed)\n",
|
| 363 |
-
" prompt = build_prompt(obs)\n",
|
| 364 |
-
" completion = generate_action(prompt)\n",
|
| 365 |
-
" action = parse_action(completion)\n",
|
| 366 |
-
" try:\n",
|
| 367 |
-
" obs = env_client.step(action)\n",
|
| 368 |
-
" reward = float(obs.reward or 0.0)\n",
|
| 369 |
-
" except:\n",
|
| 370 |
-
" reward = -0.1\n",
|
| 371 |
-
" group_rewards.append(reward)\n",
|
| 372 |
-
" group_prompts.append(prompt)\n",
|
| 373 |
-
" group_completions.append(completion)\n",
|
| 374 |
-
"\n",
|
| 375 |
-
" rewards_arr = np.array(group_rewards, dtype=np.float32)\n",
|
| 376 |
-
" advantages = (rewards_arr - rewards_arr.mean()) / (rewards_arr.std() + 1e-8)\n",
|
| 377 |
-
"\n",
|
| 378 |
-
" total_loss = torch.tensor(0.0, requires_grad=True, device=model.device)\n",
|
| 379 |
-
" optimizer.zero_grad()\n",
|
| 380 |
-
"\n",
|
| 381 |
-
" for prompt, completion, adv in zip(group_prompts, group_completions, advantages):\n",
|
| 382 |
-
" if not completion.strip():\n",
|
| 383 |
-
" continue\n",
|
| 384 |
-
" full_text = prompt + completion\n",
|
| 385 |
-
" inputs = tokenizer(full_text, return_tensors='pt', truncation=True, max_length=1200).to(model.device)\n",
|
| 386 |
-
" prompt_len = tokenizer(prompt, return_tensors='pt')[\"input_ids\"].shape[1]\n",
|
| 387 |
-
" outputs = model(**inputs, labels=inputs['input_ids'])\n",
|
| 388 |
-
" logits = outputs.logits[:, prompt_len-1:-1, :]\n",
|
| 389 |
-
" target_ids = inputs['input_ids'][:, prompt_len:]\n",
|
| 390 |
-
" if target_ids.shape[1] == 0:\n",
|
| 391 |
-
" continue\n",
|
| 392 |
-
" log_probs = torch.nn.functional.log_softmax(logits, dim=-1)\n",
|
| 393 |
-
" token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)\n",
|
| 394 |
-
" seq_log_prob = token_log_probs.mean()\n",
|
| 395 |
-
" pg_loss = -torch.tensor(float(adv), device=model.device) * seq_log_prob\n",
|
| 396 |
-
" kl_loss = KL_COEFF * (seq_log_prob ** 2)\n",
|
| 397 |
-
" total_loss = total_loss + (pg_loss + kl_loss) / GROUP_SIZE\n",
|
| 398 |
-
"\n",
|
| 399 |
-
" total_loss.backward()\n",
|
| 400 |
-
" torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
|
| 401 |
-
" optimizer.step()\n",
|
| 402 |
-
" scheduler.step()\n",
|
| 403 |
-
"\n",
|
| 404 |
-
" avg_reward = float(rewards_arr.mean())\n",
|
| 405 |
-
" training_log.append((episode, task_id, avg_reward))\n",
|
| 406 |
-
"\n",
|
| 407 |
-
" if episode % LOG_EVERY == 0:\n",
|
| 408 |
-
" print(f'Episode {episode:3d}/{N_EPISODES} | task={task_id} | avg_reward={avg_reward:.3f} | loss={total_loss.item():.4f}')\n",
|
| 409 |
-
"\n",
|
| 410 |
-
"print('Training complete!')"
|
| 411 |
-
]
|
| 412 |
-
},
|
| 413 |
-
{
|
| 414 |
-
"cell_type": "code",
|
| 415 |
-
"execution_count": null,
|
| 416 |
-
"metadata": {},
|
| 417 |
-
"outputs": [],
|
| 418 |
-
"source": [
|
| 419 |
-
"model.eval()\n",
|
| 420 |
-
"\n",
|
| 421 |
-
"print('=== POST-TRAINING EVALUATION ===')\n",
|
| 422 |
-
"trained_scores = evaluate(n_seeds=3)\n",
|
| 423 |
-
"\n",
|
| 424 |
-
"print('\\n=== IMPROVEMENT SUMMARY ===')\n",
|
| 425 |
-
"print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
|
| 426 |
-
"print('-' * 38)\n",
|
| 427 |
-
"for key, label in [(\"task1\",\"Task 1\"),(\"task2\",\"Task 2\"),(\"task3\",\"Task 3\"),(\"overall\",\"Overall\")]:\n",
|
| 428 |
-
" b = baseline_scores.get(key, 0)\n",
|
| 429 |
-
" a = trained_scores.get(key, 0)\n",
|
| 430 |
-
" d = a - b\n",
|
| 431 |
-
" print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')"
|
| 432 |
-
]
|
| 433 |
-
},
|
| 434 |
-
{
|
| 435 |
-
"cell_type": "code",
|
| 436 |
-
"execution_count": null,
|
| 437 |
-
"metadata": {},
|
| 438 |
-
"outputs": [],
|
| 439 |
-
"source": [
|
| 440 |
-
"import matplotlib.pyplot as plt\n",
|
| 441 |
-
"import numpy as np\n",
|
| 442 |
-
"\n",
|
| 443 |
-
"episodes = [x[0] for x in training_log]\n",
|
| 444 |
-
"task_ids = [x[1] for x in training_log]\n",
|
| 445 |
-
"ep_rewards = [x[2] for x in training_log]\n",
|
| 446 |
-
"\n",
|
| 447 |
-
"def moving_avg(data, window=5):\n",
|
| 448 |
-
" return np.convolve(data, np.ones(window)/window, mode='valid')\n",
|
| 449 |
-
"\n",
|
| 450 |
-
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 451 |
-
"fig.suptitle('Support Ticket Env - GRPO Training Results', fontsize=14, fontweight='bold')\n",
|
| 452 |
-
"\n",
|
| 453 |
-
"ax1 = axes[0]\n",
|
| 454 |
-
"colors = {1: '#3498db', 2: '#2ecc71', 3: '#e74c3c'}\n",
|
| 455 |
-
"for tid in [1, 2, 3]:\n",
|
| 456 |
-
" mask = [i for i, t in enumerate(task_ids) if t == tid]\n",
|
| 457 |
-
" if mask:\n",
|
| 458 |
-
" x = [episodes[i] for i in mask]\n",
|
| 459 |
-
" y = [ep_rewards[i] for i in mask]\n",
|
| 460 |
-
" ax1.scatter(x, y, alpha=0.3, color=colors[tid], s=15)\n",
|
| 461 |
-
" if len(y) >= 5:\n",
|
| 462 |
-
" smoothed = moving_avg(y)\n",
|
| 463 |
-
" ax1.plot(x[2:-2], smoothed, color=colors[tid], linewidth=2, label=f'Task {tid}')\n",
|
| 464 |
-
" else:\n",
|
| 465 |
-
" ax1.plot(x, y, color=colors[tid], linewidth=2, label=f'Task {tid}')\n",
|
| 466 |
-
"\n",
|
| 467 |
-
"ax1.set_xlabel('Episode')\n",
|
| 468 |
-
"ax1.set_ylabel('Avg Reward')\n",
|
| 469 |
-
"ax1.set_title('Training Reward per Episode')\n",
|
| 470 |
-
"ax1.legend()\n",
|
| 471 |
-
"ax1.grid(True, alpha=0.3)\n",
|
| 472 |
-
"ax1.set_ylim(-0.1, 1.1)\n",
|
| 473 |
-
"\n",
|
| 474 |
-
"ax2 = axes[1]\n",
|
| 475 |
-
"tasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\n",
|
| 476 |
-
"keys = ['task1', 'task2', 'task3', 'overall']\n",
|
| 477 |
-
"before_vals = [baseline_scores.get(k, 0) for k in keys]\n",
|
| 478 |
-
"after_vals = [trained_scores.get(k, 0) for k in keys]\n",
|
| 479 |
-
"\n",
|
| 480 |
-
"x = np.arange(len(tasks))\n",
|
| 481 |
-
"width = 0.35\n",
|
| 482 |
-
"\n",
|
| 483 |
-
"bars1 = ax2.bar(x - width/2, before_vals, width, label='Before Training', color='#95a5a6')\n",
|
| 484 |
-
"bars2 = ax2.bar(x + width/2, after_vals, width, label='After GRPO', color='#2ecc71')\n",
|
| 485 |
-
"\n",
|
| 486 |
-
"for bar in bars1:\n",
|
| 487 |
-
" ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,\n",
|
| 488 |
-
" f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
|
| 489 |
-
"for bar in bars2:\n",
|
| 490 |
-
" ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,\n",
|
| 491 |
-
" f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9,\n",
|
| 492 |
-
" fontweight='bold', color='#27ae60')\n",
|
| 493 |
-
"\n",
|
| 494 |
-
"ax2.set_xticks(x)\n",
|
| 495 |
-
"ax2.set_xticklabels(tasks)\n",
|
| 496 |
-
"ax2.set_ylabel('Score (0-1)')\n",
|
| 497 |
-
"ax2.set_title('Before vs After GRPO Training')\n",
|
| 498 |
-
"ax2.legend()\n",
|
| 499 |
-
"ax2.grid(True, alpha=0.3, axis='y')\n",
|
| 500 |
-
"ax2.set_ylim(0, 1.15)\n",
|
| 501 |
-
"\n",
|
| 502 |
-
"plt.tight_layout()\n",
|
| 503 |
-
"plt.savefig('/content/grpo_results.png', dpi=150, bbox_inches='tight')\n",
|
| 504 |
-
"plt.show()\n",
|
| 505 |
-
"print('Chart saved to /content/grpo_results.png')"
|
| 506 |
-
]
|
| 507 |
-
},
|
| 508 |
-
{
|
| 509 |
-
"cell_type": "code",
|
| 510 |
-
"execution_count": null,
|
| 511 |
-
"metadata": {},
|
| 512 |
-
"outputs": [],
|
| 513 |
-
"source": [
|
| 514 |
-
"import os\n",
|
| 515 |
-
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
|
| 516 |
-
"\n",
|
| 517 |
-
"model.save_pretrained(OUTPUT_DIR)\n",
|
| 518 |
-
"tokenizer.save_pretrained(OUTPUT_DIR)\n",
|
| 519 |
-
"print(f'Model saved to {OUTPUT_DIR}')\n",
|
| 520 |
-
"\n",
|
| 521 |
-
"try:\n",
|
| 522 |
-
" from huggingface_hub import HfApi\n",
|
| 523 |
-
" api = HfApi(token=HF_TOKEN)\n",
|
| 524 |
-
" api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
|
| 525 |
-
" api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n",
|
| 526 |
-
" api.upload_file(path_or_fileobj='/content/grpo_results.png', path_in_repo='grpo_results.png', repo_id=HF_REPO_ID, repo_type='model')\n",
|
| 527 |
-
" print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
|
| 528 |
-
"except Exception as e:\n",
|
| 529 |
-
" print(f'Push failed: {e}')\n",
|
| 530 |
-
" print(f'Model is saved locally at {OUTPUT_DIR}')"
|
| 531 |
-
]
|
| 532 |
-
},
|
| 533 |
-
{
|
| 534 |
-
"cell_type": "code",
|
| 535 |
-
"execution_count": null,
|
| 536 |
-
"metadata": {},
|
| 537 |
-
"outputs": [],
|
| 538 |
-
"source": [
|
| 539 |
-
"from google.colab import files\n",
|
| 540 |
-
"files.download('/content/grpo_results.png')\n",
|
| 541 |
-
"\n",
|
| 542 |
-
"print('\\n' + '='*50)\n",
|
| 543 |
-
"print('FINAL TRAINING SUMMARY')\n",
|
| 544 |
-
"print('='*50)\n",
|
| 545 |
-
"print(f'Model: {MODEL_NAME}')\n",
|
| 546 |
-
"print(f'Algorithm: GRPO')\n",
|
| 547 |
-
"print(f'Episodes: {N_EPISODES}')\n",
|
| 548 |
-
"print(f'Env: {ENV_BASE_URL}')\n",
|
| 549 |
-
"print()\n",
|
| 550 |
-
"print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
|
| 551 |
-
"print('-' * 38)\n",
|
| 552 |
-
"for key, label in [(\"task1\",\"Task 1\"),(\"task2\",\"Task 2\"),(\"task3\",\"Task 3\"),(\"overall\",\"Overall\")]:\n",
|
| 553 |
-
" b = baseline_scores.get(key, 0)\n",
|
| 554 |
-
" a = trained_scores.get(key, 0)\n",
|
| 555 |
-
" d = a - b\n",
|
| 556 |
-
" print(f'{label:<10} {b:>8.3f} {a:>8.3f} {d:>+8.3f}')\n",
|
| 557 |
-
"print('='*50)\n",
|
| 558 |
-
"print(f'Model: https://huggingface.co/{HF_REPO_ID}')"
|
| 559 |
-
]
|
| 560 |
-
}
|
| 561 |
-
]
|
| 562 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|