Vighnesh committed on
Commit
5648ca2
·
1 Parent(s): 7bdf1e0

add training notebook

Browse files
Files changed (1) hide show
  1. train_grpo.ipynb +15 -787
train_grpo.ipynb CHANGED
@@ -16,7 +16,7 @@
16
  "cell_type": "markdown",
17
  "metadata": {},
18
  "source": [
19
- "# Support Ticket Env — GRPO Fine-Tuning\n",
20
  "**OpenEnv x Scalar Hackathon**\n",
21
  "\n",
22
  "Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using **real GRPO** (`trl.GRPOTrainer`) + LoRA (PEFT)\n",
@@ -26,7 +26,7 @@
26
  "- **Algorithm:** GRPO via `trl.GRPOTrainer` (proper clipped ratio + KL vs reference model)\n",
27
  "- **Environment:** https://algocore-support-ticket-env.hf.space\n",
28
  "- **Runtime:** ~30-45 min on Kaggle P100/T4 (or Colab)\n",
29
- "- **No Unsloth** — standard HuggingFace transformers + PEFT"
30
  ]
31
  },
32
  {
@@ -35,10 +35,7 @@
35
  "metadata": {},
36
  "outputs": [],
37
  "source": [
38
- "# Install dependencies\n",
39
- "!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n",
40
- "!pip install -q bitsandbytes requests matplotlib\n",
41
- "print('Installation complete')"
42
  ]
43
  },
44
  {
@@ -47,52 +44,7 @@
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
50
- "import os\n",
51
- "\n",
52
- "# Load HF_TOKEN: Colab -> Kaggle -> env var\n",
53
- "HF_TOKEN = ''\n",
54
- "try:\n",
55
- " from google.colab import userdata\n",
56
- " HF_TOKEN = userdata.get('HF_TOKEN') or ''\n",
57
- "except Exception:\n",
58
- " pass\n",
59
- "\n",
60
- "if not HF_TOKEN:\n",
61
- " try:\n",
62
- " from kaggle_secrets import UserSecretsClient\n",
63
- " HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n",
64
- " except Exception:\n",
65
- " pass\n",
66
- "\n",
67
- "if not HF_TOKEN:\n",
68
- " HF_TOKEN = os.environ.get('HF_TOKEN', '')\n",
69
- "\n",
70
- "if not HF_TOKEN:\n",
71
- " raise ValueError('HF_TOKEN not found. Kaggle: Add-ons -> Secrets -> HF_TOKEN. Colab: key icon -> Secrets.')\n",
72
- "\n",
73
- "print('HF_TOKEN loaded OK')\n",
74
- "\n",
75
- "ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
76
- "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
77
- "# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n",
78
- "# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n",
79
- "# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\n",
80
- "HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
81
- "\n",
82
- "RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
83
- "OUTPUT_DIR = '/kaggle/working/support-ticket-grpo' if RUNTIME == 'kaggle' else '/content/support-ticket-grpo'\n",
84
- "RESULTS_IMG = '/kaggle/working/grpo_results.png' if RUNTIME == 'kaggle' else '/content/grpo_results.png'\n",
85
- "print(f'Runtime: {RUNTIME} | Output: {OUTPUT_DIR}')\n",
86
- "\n",
87
- "os.environ['HF_TOKEN'] = HF_TOKEN\n",
88
- "os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n",
89
- "\n",
90
- "import torch\n",
91
- "print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU — switch runtime!')\n",
92
- "if torch.cuda.is_available():\n",
93
- " print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), 'GB')\n",
94
- "print('Model:', MODEL_NAME)\n",
95
- "print('Env: ', ENV_BASE_URL)"
96
  ]
97
  },
98
  {
@@ -101,135 +53,7 @@
101
  "metadata": {},
102
  "outputs": [],
103
  "source": [
104
- "import requests, json, re, random\n",
105
- "from dataclasses import dataclass\n",
106
- "from typing import Optional\n",
107
- "\n",
108
- "TICKETS = [\n",
109
- " {'id':'T001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
110
- " {'id':'T002','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
111
- " {'id':'T003','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
112
- " {'id':'T004','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
113
- " {'id':'T005','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
114
- " {'id':'T006','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
115
- " {'id':'T007','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
116
- " {'id':'T008','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
117
- " {'id':'T009','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
118
- " {'id':'T010','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
119
- "]\n",
120
- "\n",
121
- "KEYWORD_REWARDS = {\n",
122
- " 'billing': ['charge','invoice','payment','billing','refund'],\n",
123
- " 'account': ['password','login','account','cancel','subscription'],\n",
124
- " 'technical': ['engineering','escalate','bug','crash','error'],\n",
125
- " 'refund': ['refund','return','credit','process'],\n",
126
- " 'general': ['hours','contact','phone','information','help'],\n",
127
- "}\n",
128
- "\n",
129
- "@dataclass\n",
130
- "class Obs:\n",
131
- " ticket_id: str\n",
132
- " ticket_text: str\n",
133
- " task_id: int\n",
134
- " current_category: Optional[str]\n",
135
- " resolved: bool\n",
136
- " step_count: int\n",
137
- " feedback: str\n",
138
- " score: float\n",
139
- " reward: float\n",
140
- " done: bool\n",
141
- "\n",
142
- "class LocalEnv:\n",
143
- " \"\"\"Local mirror of live HF Space — same reward logic, used as fallback.\"\"\"\n",
144
- " def reset(self, task_id=1, seed=42):\n",
145
- " rng = random.Random(seed)\n",
146
- " self.task_id = task_id\n",
147
- " self.ticket = rng.choice(TICKETS)\n",
148
- " self.classified = False\n",
149
- " self.step_count = 0\n",
150
- " return Obs(self.ticket['id'], self.ticket['text'], task_id,\n",
151
- " None, False, 0, 'New ticket. Take action.', 0.0, 0.0, False)\n",
152
- " def step(self, action):\n",
153
- " self.step_count += 1\n",
154
- " at = action.get('action_type', '')\n",
155
- " cat = action.get('category', '')\n",
156
- " reply = action.get('reply_text', '')\n",
157
- " reward = 0.0; done = False\n",
158
- " if self.task_id == 1:\n",
159
- " reward = 1.0 if cat == self.ticket['category'] else 0.0\n",
160
- " done = True\n",
161
- " elif self.task_id == 2:\n",
162
- " if not self.classified:\n",
163
- " reward = 0.3 if cat == self.ticket['category'] else 0.1\n",
164
- " self.classified = True\n",
165
- " else:\n",
166
- " reward = 1.0 if at == self.ticket['correct_action'] else 0.0\n",
167
- " done = True\n",
168
- " else:\n",
169
- " if not self.classified:\n",
170
- " reward = 0.2 if cat == self.ticket['category'] else 0.0\n",
171
- " self.classified = True\n",
172
- " else:\n",
173
- " action_score = 0.4 if at == self.ticket['correct_action'] else 0.0\n",
174
- " kws = KEYWORD_REWARDS.get(self.ticket['category'], [])\n",
175
- " reply_score = min(0.25, sum(0.05 for kw in kws if kw in reply.lower()))\n",
176
- " reward = action_score + reply_score\n",
177
- " done = True\n",
178
- " return Obs(self.ticket['id'], self.ticket['text'], self.task_id,\n",
179
- " self.ticket['category'] if self.classified else None,\n",
180
- " done, self.step_count, f'reward={reward:.2f}', reward, reward, done)\n",
181
- "\n",
182
- "class RemoteEnv:\n",
183
- " \"\"\"Live HF Space API.\"\"\"\n",
184
- " def __init__(self, base_url):\n",
185
- " self.base_url = base_url.rstrip('/')\n",
186
- " self.session = requests.Session()\n",
187
- " self.session.headers.update({'Content-Type': 'application/json'})\n",
188
- " def health(self):\n",
189
- " try:\n",
190
- " r = self.session.get(f'{self.base_url}/health', timeout=8)\n",
191
- " return r.status_code == 200\n",
192
- " except: return False\n",
193
- " def reset(self, task_id=1, seed=42):\n",
194
- " r = self.session.post(f'{self.base_url}/reset', json={'task_id': task_id, 'seed': seed}, timeout=15)\n",
195
- " r.raise_for_status()\n",
196
- " obs = r.json().get('observation', r.json())\n",
197
- " return self._parse_obs(obs)\n",
198
- " def step(self, action):\n",
199
- " r = self.session.post(f'{self.base_url}/step', json={'action': action}, timeout=15)\n",
200
- " r.raise_for_status()\n",
201
- " obs = r.json().get('observation', r.json())\n",
202
- " return self._parse_obs(obs)\n",
203
- " def _parse_obs(self, obs):\n",
204
- " # Safely coerce each field — avoids 'Field' object errors from dataclass defaults\n",
205
- " fields = Obs.__dataclass_fields__\n",
206
- " def safe(k, fallback):\n",
207
- " v = obs.get(k, fallback)\n",
208
- " if isinstance(v, type): return fallback # guard against dataclass Field objects\n",
209
- " return v\n",
210
- " return Obs(\n",
211
- " ticket_id=safe('ticket_id', ''),\n",
212
- " ticket_text=safe('ticket_text', ''),\n",
213
- " task_id=int(safe('task_id', 1)),\n",
214
- " current_category=safe('current_category', None),\n",
215
- " resolved=bool(safe('resolved', False)),\n",
216
- " step_count=int(safe('step_count', 0)),\n",
217
- " feedback=safe('feedback', ''),\n",
218
- " score=float(safe('score', 0.0)),\n",
219
- " reward=float(safe('reward', 0.0)),\n",
220
- " done=bool(safe('done', False)),\n",
221
- " )\n",
222
- "\n",
223
- "_remote = RemoteEnv(ENV_BASE_URL)\n",
224
- "if _remote.health():\n",
225
- " env_client = _remote\n",
226
- " print('Using LIVE environment:', ENV_BASE_URL)\n",
227
- "else:\n",
228
- " env_client = LocalEnv()\n",
229
- " print('Live API unreachable — using LOCAL mirror')\n",
230
- "\n",
231
- "obs = env_client.reset(task_id=1, seed=42)\n",
232
- "print(f'Ticket: {obs.ticket_id} — {obs.ticket_text[:60]}')"
233
  ]
234
  },
235
  {
@@ -238,39 +62,7 @@
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
241
- "import torch\n",
242
- "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
243
- "from peft import LoraConfig, TaskType\n",
244
- "\n",
245
- "MAX_SEQ_LENGTH = 512\n",
246
- "print(f'Loading {MODEL_NAME}...')\n",
247
- "\n",
248
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\n",
249
- "tokenizer.pad_token = tokenizer.eos_token\n",
250
- "tokenizer.padding_side = 'left'\n",
251
- "\n",
252
- "# Qwen2.5-0.5B = ~1GB in fp16 — fits easily in 15.6GB T4, no quantization needed\n",
253
- "# bitsandbytes 4-bit + DataParallel + gradient checkpointing = CUDA illegal memory access\n",
254
- "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
255
- "model = AutoModelForCausalLM.from_pretrained(\n",
256
- " MODEL_NAME,\n",
257
- " dtype=torch.float16,\n",
258
- " device_map={'': 0},\n",
259
- " token=HF_TOKEN,\n",
260
- ")\n",
261
- "model.config.use_cache = False\n",
262
- "\n",
263
- "peft_config = LoraConfig(\n",
264
- " task_type=TaskType.CAUSAL_LM,\n",
265
- " r=16,\n",
266
- " lora_alpha=32,\n",
267
- " target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
268
- " lora_dropout=0.05,\n",
269
- " bias='none',\n",
270
- ")\n",
271
- "\n",
272
- "print('Model loaded — LoRA config ready (GRPOTrainer will apply PEFT internally)')\n",
273
- "print(f'Model params: {sum(p.numel() for p in model.parameters()):,}')"
274
  ]
275
  },
276
  {
@@ -279,64 +71,7 @@
279
  "metadata": {},
280
  "outputs": [],
281
  "source": [
282
- "SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n",
283
- "\n",
284
- "VALID action_type values: classify, reply, escalate, close\n",
285
- "VALID category values: billing, technical, account, general, refund\n",
286
- "\n",
287
- "For classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\n",
288
- "For reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\n",
289
- "For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n",
290
- "For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n",
291
- "\n",
292
- "RULES:\n",
293
- "- task_id=1: ALWAYS output action_type=classify first\n",
294
- "- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n",
295
- "- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n",
296
- "- technical/crash/error/bug tickets -> escalate\n",
297
- "- thank you/resolved tickets -> close\n",
298
- "- billing/account/refund/general -> reply\n",
299
- "- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n",
300
- "\n",
301
- "def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n",
302
- " user_msg = json.dumps({\n",
303
- " 'ticket': ticket_text,\n",
304
- " 'task_id': task_id,\n",
305
- " 'current_category': current_category,\n",
306
- " 'feedback': feedback,\n",
307
- " 'step': step,\n",
308
- " })\n",
309
- " messages = [\n",
310
- " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
311
- " {'role': 'user', 'content': user_msg},\n",
312
- " ]\n",
313
- " return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
314
- "\n",
315
- "def parse_action(text):\n",
316
- " text = text.strip()\n",
317
- " # Strip markdown code blocks\n",
318
- " text = re.sub(r'^```(?:json)?\\s*', '', text)\n",
319
- " text = re.sub(r'\\s*```$', '', text.strip())\n",
320
- " try:\n",
321
- " return json.loads(text)\n",
322
- " except Exception:\n",
323
- " match = re.search(r'\\{[^{}]*\\}', text, re.DOTALL)\n",
324
- " if match:\n",
325
- " try: return json.loads(match.group())\n",
326
- " except: pass\n",
327
- " return {'action_type': 'classify', 'category': 'general'}\n",
328
- "\n",
329
- "def _safe_parse(completion):\n",
330
- " \"\"\"Always returns a dict, never a string.\"\"\"\n",
331
- " result = parse_action(completion) if isinstance(completion, str) else {}\n",
332
- " if not isinstance(result, dict):\n",
333
- " return {'action_type': '', 'category': '', 'reply_text': ''}\n",
334
- " return result\n",
335
- "\n",
336
- "print('Prompt builder OK')\n",
337
- "# Quick sanity check\n",
338
- "sample = make_prompt('I was charged twice', task_id=1)\n",
339
- "print('Sample prompt length (chars):', len(sample))"
340
  ]
341
  },
342
  {
@@ -345,146 +80,7 @@
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
348
- "# ─────────────────────────────────────────────────────────────────\n",
349
- "# Build LARGE dataset for GRPOTrainer\n",
350
- "# Strategy:\n",
351
- "# 1. Expanded ticket bank (50 tickets across all categories)\n",
352
- "# 2. All 3 task types x many seeds\n",
353
- "# 3. Multi-step contexts: step-0 (classify) AND step-1 (resolve)\n",
354
- "# 4. Paraphrase augmentation of ticket text\n",
355
- "# Target: ~500+ training samples\n",
356
- "# ─────────────────────────────────────────────────────────────────\n",
357
- "from datasets import Dataset\n",
358
- "\n",
359
- "MAX_STEPS = 6\n",
360
- "TASK_IDS = [1, 2, 3]\n",
361
- "\n",
362
- "# Large seed pool\n",
363
- "SEEDS = list(range(0, 200)) # 200 seeds\n",
364
- "\n",
365
- "# Expanded ticket bank — 50 tickets covering all categories\n",
366
- "ALL_TICKETS = [\n",
367
- " # billing (12)\n",
368
- " {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
369
- " {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
370
- " {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
371
- " {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n",
372
- " {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n",
373
- " {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n",
374
- " {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n",
375
- " {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n",
376
- " {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n",
377
- " {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n",
378
- " {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n",
379
- " {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n",
380
- " # account (10)\n",
381
- " {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
382
- " {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
383
- " {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n",
384
- " {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n",
385
- " {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n",
386
- " {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n",
387
- " {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n",
388
- " {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n",
389
- " {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n",
390
- " {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n",
391
- " # technical (10)\n",
392
- " {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
393
- " {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
394
- " {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n",
395
- " {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n",
396
- " {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n",
397
- " {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n",
398
- " {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n",
399
- " {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n",
400
- " {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n",
401
- " {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n",
402
- " # refund (8)\n",
403
- " {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
404
- " {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n",
405
- " {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n",
406
- " {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n",
407
- " {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n",
408
- " {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n",
409
- " {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n",
410
- " {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n",
411
- " # general (10)\n",
412
- " {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
413
- " {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
414
- " {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n",
415
- " {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n",
416
- " {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n",
417
- " {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n",
418
- " {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n",
419
- " {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n",
420
- " {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n",
421
- " {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n",
422
- "]\n",
423
- "\n",
424
- "KEYWORD_REWARDS_FULL = {\n",
425
- " 'billing': ['charge','invoice','payment','billing','refund','receipt'],\n",
426
- " 'account': ['password','login','account','cancel','subscription','email'],\n",
427
- " 'technical': ['engineering','escalate','bug','crash','error','fix'],\n",
428
- " 'refund': ['refund','return','credit','process','reimburse'],\n",
429
- " 'general': ['hours','contact','phone','information','help','available'],\n",
430
- "}\n",
431
- "\n",
432
- "def build_grpo_dataset():\n",
433
- " rows = []\n",
434
- " rng = random.Random(2026)\n",
435
- "\n",
436
- " for task_id in TASK_IDS:\n",
437
- " for seed in SEEDS:\n",
438
- " # Pick a ticket deterministically from expanded bank\n",
439
- " ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
440
- "\n",
441
- " # --- Step 0: classify context ---\n",
442
- " prompt_step0 = make_prompt(\n",
443
- " ticket_text=ticket['text'],\n",
444
- " task_id=task_id,\n",
445
- " current_category=None,\n",
446
- " feedback='New ticket. Classify it first.',\n",
447
- " step=0,\n",
448
- " )\n",
449
- " rows.append({\n",
450
- " 'prompt': prompt_step0,\n",
451
- " 'ticket_text': ticket['text'],\n",
452
- " 'task_id': task_id,\n",
453
- " 'seed': seed,\n",
454
- " 'step': 0,\n",
455
- " })\n",
456
- "\n",
457
- " # --- Step 1: resolve context (tasks 2 & 3 only) ---\n",
458
- " if task_id in (2, 3):\n",
459
- " prompt_step1 = make_prompt(\n",
460
- " ticket_text=ticket['text'],\n",
461
- " task_id=task_id,\n",
462
- " current_category=ticket['category'],\n",
463
- " feedback=f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n",
464
- " step=1,\n",
465
- " )\n",
466
- " rows.append({\n",
467
- " 'prompt': prompt_step1,\n",
468
- " 'ticket_text': ticket['text'],\n",
469
- " 'task_id': task_id,\n",
470
- " 'seed': seed + 10000, # unique seed key for step-1\n",
471
- " 'step': 1,\n",
472
- " })\n",
473
- "\n",
474
- " # Shuffle so tasks/steps are interleaved during training\n",
475
- " rng.shuffle(rows)\n",
476
- " return Dataset.from_list(rows)\n",
477
- "\n",
478
- "grpo_dataset = build_grpo_dataset()\n",
479
- "print(f'Dataset built: {len(grpo_dataset)} samples')\n",
480
- "# breakdown\n",
481
- "from collections import Counter\n",
482
- "task_counts = Counter(grpo_dataset['task_id'])\n",
483
- "step_counts = Counter(grpo_dataset['step'])\n",
484
- "print(f' By task: {dict(task_counts)}')\n",
485
- "print(f' By step: {dict(step_counts)}')\n",
486
- "print('Sample prompt (first 300 chars):')\n",
487
- "print(grpo_dataset[0]['prompt'][:300])"
488
  ]
489
  },
490
  {
@@ -493,142 +89,7 @@
493
  "metadata": {},
494
  "outputs": [],
495
  "source": [
496
- "# ─────────────────────────────────────────────────────────────────\n",
497
- "# Reward functions — exact mirror of graders.py\n",
498
- "# grade_task1 / grade_task2 / grade_task3 / loop_penalty\n",
499
- "# ─────────────────────────────────────────────────────────────────\n",
500
- "\n",
501
- "# Partial-credit action pairs (from graders.py)\n",
502
- "_PARTIAL_CREDIT_PAIRS = {frozenset({'reply', 'escalate'})}\n",
503
- "\n",
504
- "# Keyword lists (from graders.py)\n",
505
- "_KEYWORD_REWARDS = {\n",
506
- " 'billing': ['refund', 'charge', 'invoice', 'payment', 'billing'],\n",
507
- " 'account': ['password', 'login', 'account', 'cancel', 'subscription'],\n",
508
- " 'technical': ['engineering', 'escalate', 'bug', 'crash', 'error', 'fix'],\n",
509
- " 'refund': ['refund', 'return', 'credit', 'process'],\n",
510
- " 'general': ['hours', 'contact', 'phone', 'information', 'help'],\n",
511
- "}\n",
512
- "\n",
513
- "def _reply_quality(reply_text, category):\n",
514
- " \"\"\"Exact copy of graders._reply_quality: 0.0–0.5 keyword score.\"\"\"\n",
515
- " if not reply_text: return 0.0\n",
516
- " hits = sum(1 for kw in _KEYWORD_REWARDS.get(category, []) if kw in reply_text.lower())\n",
517
- " return min(0.5, hits * 0.1)\n",
518
- "\n",
519
- "def _grade_task1(at, cat, correct_cat):\n",
520
- " \"\"\"Exact copy of graders.grade_task1.\"\"\"\n",
521
- " return 1.0 if (at == 'classify' and cat == correct_cat) else 0.0\n",
522
- "\n",
523
- "def _grade_task2(at, correct_action, step, cat, correct_cat):\n",
524
- " \"\"\"Exact copy of graders.grade_task2 + classify step.\"\"\"\n",
525
- " if step == 0:\n",
526
- " # classify step: partial credit for correct category\n",
527
- " if at == 'classify' and cat == correct_cat: return 0.3\n",
528
- " if at == 'classify': return 0.1\n",
529
- " return 0.0\n",
530
- " # action step\n",
531
- " if at == correct_action: return 1.0\n",
532
- " if frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS: return 0.5\n",
533
- " if at == 'close': return 0.0\n",
534
- " return 0.0\n",
535
- "\n",
536
- "def _grade_task3(at, cat, correct_cat, correct_action, reply, step, steps_taken=2, max_steps=5):\n",
537
- " \"\"\"Exact copy of graders.grade_task3.\"\"\"\n",
538
- " if step == 0:\n",
539
- " # classification step only\n",
540
- " return 0.20 if (at == 'classify' and cat == correct_cat) else 0.0\n",
541
- " # resolution step: 0.40 action + up to 0.50 reply + 0.15 efficiency\n",
542
- " score = 0.0\n",
543
- " classified_correctly = True # step-1 means step-0 already happened\n",
544
- " score += 0.20 # classification credit carried from step 0\n",
545
- " action_correct = (at == correct_action)\n",
546
- " action_partial = (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n",
547
- " if action_correct: score += 0.40\n",
548
- " elif action_partial: score += 0.20\n",
549
- " score += _reply_quality(reply, cat) # 0.0–0.5\n",
550
- " # efficiency bonus (assume 2 steps taken for step-1 samples)\n",
551
- " resolved = action_correct or action_partial\n",
552
- " if resolved and steps_taken <= max_steps:\n",
553
- " efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n",
554
- " score += 0.15 * efficiency\n",
555
- " return round(min(1.0, score), 4)\n",
556
- "\n",
557
- "def _loop_penalty(step_count, max_steps=10):\n",
558
- " \"\"\"Exact copy of graders.loop_penalty.\"\"\"\n",
559
- " return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n",
560
- "\n",
561
- "def _local_reward(completion, task_id, seed, step=0):\n",
562
- " \"\"\"Full reward using exact graders.py logic. No API calls needed.\"\"\"\n",
563
- " ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
564
- " action = _safe_parse(completion)\n",
565
- " if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
566
- " at = action.get('action_type', '')\n",
567
- " cat = action.get('category', '')\n",
568
- " reply = action.get('reply_text', '') or ''\n",
569
- " correct_cat = ticket['category']\n",
570
- " correct_action = ticket['correct_action']\n",
571
- "\n",
572
- " if task_id == 1:\n",
573
- " return _grade_task1(at, cat, correct_cat)\n",
574
- " elif task_id == 2:\n",
575
- " return _grade_task2(at, correct_action, step, cat, correct_cat)\n",
576
- " else: # task 3\n",
577
- " return _grade_task3(at, cat, correct_cat, correct_action, reply, step)\n",
578
- "\n",
579
- "def env_reward_fn(prompts, completions, **kwargs):\n",
580
- " \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n",
581
- " task_ids = kwargs.get('task_id', [1] * len(completions))\n",
582
- " seeds = kwargs.get('seed', [42] * len(completions))\n",
583
- " steps = kwargs.get('step', [0] * len(completions))\n",
584
- " rewards = []\n",
585
- " for i, completion in enumerate(completions):\n",
586
- " tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n",
587
- " seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n",
588
- " step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n",
589
- " actual_seed = seed % 10000 if seed >= 10000 else seed\n",
590
- " r = _local_reward(completion, tid, actual_seed, step)\n",
591
- " # apply loop penalty if step is high\n",
592
- " r += _loop_penalty(step)\n",
593
- " rewards.append(r)\n",
594
- " return rewards\n",
595
- "\n",
596
- "def format_reward_fn(prompts, completions, **kwargs):\n",
597
- " \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n",
598
- " rewards = []\n",
599
- " for completion in completions:\n",
600
- " action = _safe_parse(completion)\n",
601
- " if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
602
- " at = action.get('action_type', '')\n",
603
- " if at in ('classify', 'reply', 'escalate', 'close'):\n",
604
- " bonus = 0.15\n",
605
- " if at == 'classify' and action.get('category') in ('billing','technical','account','general','refund'):\n",
606
- " bonus = 0.20\n",
607
- " rewards.append(bonus)\n",
608
- " else:\n",
609
- " rewards.append(-0.20)\n",
610
- " return rewards\n",
611
- "\n",
612
- "# Print ticket map\n",
613
- "print('Reward functions synced to graders.py')\n",
614
- "print('Ticket map (seed % len):')\n",
615
- "for _i in range(6):\n",
616
- " _tt = ALL_TICKETS[_i]\n",
617
- " print(f' [{_i}] {_tt[\"id\"]} cat={_tt[\"category\"]} action={_tt[\"correct_action\"]}')\n",
618
- "\n",
619
- "# Sanity: seed=0->B001(billing,reply), seed=22->T001(technical,escalate)\n",
620
- "_t0 = ALL_TICKETS[0] # B001 billing reply\n",
621
- "_t22 = ALL_TICKETS[22] # T001 technical escalate\n",
622
- "r1 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 1, 0, 0)\n",
623
- "r2 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 2, 0, 0)\n",
624
- "r3 = _local_reward(json.dumps({'action_type':'escalate'}), 2, 0, 1)\n",
625
- "r4 = _local_reward(json.dumps({'action_type':_t22['correct_action'],'reply_text':'escalating this crash bug error to engineering team for a fix'}), 3, 22, 1)\n",
626
- "r5 = format_reward_fn(prompts=['x'], completions=[json.dumps({'action_type':'respond'})])[0]\n",
627
- "print(f'task1 correct classify: {r1} (expect 1.0)')\n",
628
- "print(f'task2 step0 correct classify: {r2} (expect 0.3)')\n",
629
- "print(f'task2 step1 partial escalate: {r3} (expect 0.5)')\n",
630
- "print(f'task3 step1 correct+keywords: {r4} (expect 0.87+)')\n",
631
- "print(f'bad format penalty: {r5} (expect -0.2)')\n"
632
  ]
633
  },
634
  {
@@ -637,65 +98,7 @@
637
  "metadata": {},
638
  "outputs": [],
639
  "source": [
640
- "# ─────────────────────────────────────────────────────────────────\n",
641
- "# Baseline evaluation BEFORE training\n",
642
- "# ─────────────────────────────────────────────────────────────────\n",
643
- "def quick_generate(prompt_text, max_new_tokens=120):\n",
644
- " model.eval()\n",
645
- " model.config.use_cache = True\n",
646
- " inputs = tokenizer(\n",
647
- " prompt_text, return_tensors='pt',\n",
648
- " truncation=True, max_length=MAX_SEQ_LENGTH\n",
649
- " ).to(DEVICE)\n",
650
- " with torch.no_grad():\n",
651
- " out = model.generate(\n",
652
- " **inputs,\n",
653
- " max_new_tokens=max_new_tokens,\n",
654
- " do_sample=False, # greedy for eval — deterministic\n",
655
- " pad_token_id=tokenizer.eos_token_id,\n",
656
- " use_cache=True,\n",
657
- " )\n",
658
- " new_tokens = out[0][inputs['input_ids'].shape[1]:]\n",
659
- " return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
660
- "\n",
661
- "def evaluate(n_seeds=3, verbose=False):\n",
662
- " model.config.use_cache = True\n",
663
- " results = {}\n",
664
- " for task_id in [1, 2, 3]:\n",
665
- " task_rewards = []\n",
666
- " # Use LocalEnv for eval - live env is stateful/single-instance, causes 500s\n",
667
- " _eval_env = LocalEnv()\n",
668
- " EVAL_SEEDS = [42, 7, 123, 99, 13, 0, 1, 2, 5, 8]\n",
669
- " for seed in EVAL_SEEDS[:n_seeds]:\n",
670
- " obs = _eval_env.reset(task_id=task_id, seed=seed)\n",
671
- " total = 0.0\n",
672
- " done = False\n",
673
- " steps = 0\n",
674
- " for _ in range(MAX_STEPS):\n",
675
- " if done: break\n",
676
- " prompt = make_prompt(obs.ticket_text, obs.task_id, obs.current_category, obs.feedback, obs.step_count)\n",
677
- " completion = quick_generate(prompt)\n",
678
- " action = parse_action(completion)\n",
679
- " if verbose: print(f' T{task_id} s{seed} step{steps+1}: {action}')\n",
680
- " try:\n",
681
- " obs = _eval_env.step(action)\n",
682
- " total += float(obs.reward or 0.0)\n",
683
- " done = obs.done\n",
684
- " except Exception as e:\n",
685
- " if verbose: print(f' [err] {e}')\n",
686
- " done = True\n",
687
- " steps += 1\n",
688
- " norm = round(max(0.0, min(1.0, total / max(steps, 1))), 3)\n",
689
- " task_rewards.append(norm)\n",
690
- " avg = round(sum(task_rewards) / len(task_rewards), 3)\n",
691
- " results[f'task{task_id}'] = avg\n",
692
- " print(f' Task {task_id}: {avg:.3f}')\n",
693
- " results['overall'] = round(sum(results[k] for k in ['task1','task2','task3']) / 3, 3)\n",
694
- " print(f' Overall: {results[\"overall\"]:.3f}')\n",
695
- " return results\n",
696
- "\n",
697
- "print('=== BASELINE (before training) ===')\n",
698
- "baseline_scores = evaluate(n_seeds=3, verbose=True)"
699
  ]
700
  },
701
  {
@@ -704,74 +107,7 @@
704
  "metadata": {},
705
  "outputs": [],
706
  "source": [
707
- "# ─────────────────────────────────────────────────────────────────\n",
708
- "# GRPO Training with trl.GRPOTrainer\n",
709
- "# This is REAL GRPO:\n",
710
- "# - Maintains a frozen reference model for KL divergence\n",
711
- "# - Clips probability ratios (PPO-style)\n",
712
- "# - Groups completions, normalises advantages within group\n",
713
- "# ─────────────────────────────────────────────────────────────────\n",
714
- "from trl import GRPOConfig, GRPOTrainer\n",
715
- "\n",
716
- "grpo_config = GRPOConfig(\n",
717
- " # Output\n",
718
- " output_dir=OUTPUT_DIR,\n",
719
- "\n",
720
- " # Training scale\n",
721
- " num_train_epochs=3, # 3 passes over the ~500-sample dataset (~1500 prompt visits; optimizer steps depend on batch size and gradient accumulation)\n",
722
- " per_device_train_batch_size=2,\n",
723
- " gradient_accumulation_steps=4,\n",
724
- "\n",
725
- " # GRPO-specific\n",
726
- " num_generations=4, # group size G — completions sampled per prompt\n",
727
- " max_prompt_length=384,\n",
728
- " max_completion_length=128,\n",
729
- " temperature=0.9,\n",
730
- " beta=0.04, # KL coefficient against reference model\n",
731
- "\n",
732
- " # Optimiser\n",
733
- " learning_rate=5e-5,\n",
734
- " lr_scheduler_type='cosine',\n",
735
- " warmup_ratio=0.1,\n",
736
- " weight_decay=0.01,\n",
737
- " max_grad_norm=1.0,\n",
738
- " optim='adamw_torch',\n",
739
- "\n",
740
- " # Logging\n",
741
- " logging_steps=5,\n",
742
- " save_strategy='no',\n",
743
- " report_to='none',\n",
744
- "\n",
745
- " # Memory — model loaded in fp16 natively, no quantization wrapper\n",
746
- " bf16=False,\n",
747
- " fp16=True, # fp16 mixed-precision (AMP) training; this flag does not store optimizer state in fp16\n",
748
- " dataloader_pin_memory=False,\n",
749
- " remove_unused_columns=False, # keep task_id, seed, step columns for reward fn\n",
750
- " ddp_find_unused_parameters=False, # find_unused_parameters flag for DistributedDataParallel — only used in distributed runs, does not disable DataParallel\n",
751
- ")\n",
752
- "\n",
753
- "trainer = GRPOTrainer(\n",
754
- " model=model,\n",
755
- " args=grpo_config,\n",
756
- " train_dataset=grpo_dataset,\n",
757
- " reward_funcs=[env_reward_fn, format_reward_fn], # multiple reward signals\n",
758
- " peft_config=peft_config,\n",
759
- " processing_class=tokenizer,\n",
760
- ")\n",
761
- "\n",
762
- "print('GRPOTrainer initialised')\n",
763
- "print(f'Dataset size: {len(grpo_dataset)} samples')\n",
764
- "print(f'Group size (G): {grpo_config.num_generations}')\n",
765
- "print(f'KL beta: {grpo_config.beta}')\n",
766
- "print(f'Max completion: {grpo_config.max_completion_length} tokens')\n",
767
- "print('Starting GRPO training...')\n",
768
- "print('=' * 60)\n",
769
- "\n",
770
- "train_result = trainer.train()\n",
771
- "\n",
772
- "print('=' * 60)\n",
773
- "print('Training complete!')\n",
774
- "print(f'Loss: {train_result.training_loss:.4f}')"
775
  ]
776
  },
777
  {
@@ -780,22 +116,7 @@
780
  "metadata": {},
781
  "outputs": [],
782
  "source": [
783
- "# ─────────────────────────────────────────────────────────────────\n",
784
- "# Post-training evaluation\n",
785
- "# ─────────────────────────────────────────────────────────────────\n",
786
- "model.config.use_cache = True\n",
787
- "model.eval()\n",
788
- "\n",
789
- "print('=== POST-TRAINING EVALUATION ===')\n",
790
- "trained_scores = evaluate(n_seeds=3)\n",
791
- "\n",
792
- "print('\\n=== IMPROVEMENT SUMMARY ===')\n",
793
- "print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
794
- "print('-' * 38)\n",
795
- "for key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n",
796
- " b = baseline_scores.get(key, 0)\n",
797
- " a = trained_scores.get(key, 0)\n",
798
- " print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')"
799
  ]
800
  },
801
  {
@@ -804,52 +125,7 @@
804
  "metadata": {},
805
  "outputs": [],
806
  "source": [
807
- "import matplotlib.pyplot as plt\n",
808
- "import numpy as np\n",
809
- "\n",
810
- "# Extract the loss history (and the steps whose log entries mention a reward) from trainer logs\n",
811
- "log_history = trainer.state.log_history\n",
812
- "train_steps = [l['step'] for l in log_history if 'loss' in l]\n",
813
- "train_losses = [l['loss'] for l in log_history if 'loss' in l]\n",
814
- "reward_steps = [l['step'] for l in log_history if 'reward' in str(l)]\n",
815
- "\n",
816
- "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
817
- "fig.suptitle('Support Ticket Env — GRPO Training Results', fontsize=14, fontweight='bold')\n",
818
- "\n",
819
- "# Left: training loss\n",
820
- "ax1 = axes[0]\n",
821
- "if train_steps:\n",
822
- " ax1.plot(train_steps, train_losses, color='#3498db', linewidth=2)\n",
823
- " ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
824
- " ax1.set_title('GRPO Training Loss')\n",
825
- " ax1.grid(True, alpha=0.3)\n",
826
- "else:\n",
827
- " ax1.text(0.5, 0.5, 'No loss logs', ha='center', va='center', transform=ax1.transAxes)\n",
828
- "\n",
829
- "# Right: before vs after bar chart\n",
830
- "ax2 = axes[1]\n",
831
- "tasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\n",
832
- "keys = ['task1', 'task2', 'task3', 'overall']\n",
833
- "bv = [baseline_scores.get(k, 0) for k in keys]\n",
834
- "av = [trained_scores.get(k, 0) for k in keys]\n",
835
- "x = np.arange(len(tasks)); w = 0.35\n",
836
- "b1 = ax2.bar(x - w/2, bv, w, label='Before GRPO', color='#95a5a6')\n",
837
- "b2 = ax2.bar(x + w/2, av, w, label='After GRPO', color='#2ecc71')\n",
838
- "for bar in b1:\n",
839
- " ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n",
840
- " f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
841
- "for bar in b2:\n",
842
- " ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n",
843
- " f'{bar.get_height():.2f}', ha='center', va='bottom',\n",
844
- " fontsize=9, fontweight='bold', color='#27ae60')\n",
845
- "ax2.set_xticks(x); ax2.set_xticklabels(tasks)\n",
846
- "ax2.set_ylabel('Score (0–1)'); ax2.set_title('Before vs After GRPO')\n",
847
- "ax2.legend(); ax2.grid(True, alpha=0.3, axis='y'); ax2.set_ylim(0, 1.15)\n",
848
- "\n",
849
- "plt.tight_layout()\n",
850
- "plt.savefig(RESULTS_IMG, dpi=150, bbox_inches='tight')\n",
851
- "plt.show()\n",
852
- "print(f'Chart saved to {RESULTS_IMG}')"
853
  ]
854
  },
855
  {
@@ -858,27 +134,7 @@
858
  "metadata": {},
859
  "outputs": [],
860
  "source": [
861
- "import os\n",
862
- "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
863
- "trainer.save_model(OUTPUT_DIR)\n",
864
- "tokenizer.save_pretrained(OUTPUT_DIR)\n",
865
- "print(f'Model saved to {OUTPUT_DIR}')\n",
866
- "\n",
867
- "try:\n",
868
- " from huggingface_hub import HfApi\n",
869
- " api = HfApi(token=HF_TOKEN)\n",
870
- " api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
871
- " api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n",
872
- " api.upload_file(\n",
873
- " path_or_fileobj=RESULTS_IMG,\n",
874
- " path_in_repo='grpo_results.png',\n",
875
- " repo_id=HF_REPO_ID,\n",
876
- " repo_type='model'\n",
877
- " )\n",
878
- " print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
879
- "except Exception as e:\n",
880
- " print(f'HF push failed: {e}')\n",
881
- " print(f'Model saved locally at {OUTPUT_DIR}')"
882
  ]
883
  },
884
  {
@@ -887,35 +143,7 @@
887
  "metadata": {},
888
  "outputs": [],
889
  "source": [
890
- "# Download chart (Colab only — on Kaggle it is in the Output tab)\n",
891
- "if RUNTIME == 'colab':\n",
892
- " try:\n",
893
- " from google.colab import files\n",
894
- " files.download(RESULTS_IMG)\n",
895
- " except Exception as e:\n",
896
- " print(f'Download skipped: {e}')\n",
897
- "else:\n",
898
- " print(f'Kaggle: chart in Output tab -> {RESULTS_IMG}')\n",
899
- "\n",
900
- "print('\\n' + '='*55)\n",
901
- "print('FINAL TRAINING SUMMARY')\n",
902
- "print('='*55)\n",
903
- "print(f'Model: {MODEL_NAME}')\n",
904
- "print(f'Algorithm: GRPO (trl.GRPOTrainer) + LoRA')\n",
905
- "print(f'Group size G: {grpo_config.num_generations}')\n",
906
- "print(f'KL beta: {grpo_config.beta}')\n",
907
- "print(f'Dataset size: {len(grpo_dataset)} prompts')\n",
908
- "print(f'Env: {ENV_BASE_URL}')\n",
909
- "print(f'Final loss: {train_result.training_loss:.4f}')\n",
910
- "print()\n",
911
- "print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
912
- "print('-' * 42)\n",
913
- "for key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n",
914
- " b = baseline_scores.get(key, 0)\n",
915
- " a = trained_scores.get(key, 0)\n",
916
- " print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')\n",
917
- "print('='*55)\n",
918
- "print(f'HF Model: https://huggingface.co/{HF_REPO_ID}')"
919
  ]
920
  }
921
  ]
 
16
  "cell_type": "markdown",
17
  "metadata": {},
18
  "source": [
19
+ "# Support Ticket Env \u2014 GRPO Fine-Tuning\n",
20
  "**OpenEnv x Scalar Hackathon**\n",
21
  "\n",
22
  "Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using **real GRPO** (`trl.GRPOTrainer`) + LoRA (PEFT)\n",
 
26
  "- **Algorithm:** GRPO via `trl.GRPOTrainer` (proper clipped ratio + KL vs reference model)\n",
27
  "- **Environment:** https://algocore-support-ticket-env.hf.space\n",
28
  "- **Runtime:** ~30-45 min on Kaggle P100/T4 (or Colab)\n",
29
+ "- **No Unsloth** \u2014 standard HuggingFace transformers + PEFT"
30
  ]
31
  },
32
  {
 
35
  "metadata": {},
36
  "outputs": [],
37
  "source": [
38
+ "# Install dependencies\n!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n!pip install -q bitsandbytes requests matplotlib wandb\nprint('Installation complete')"
 
 
 
39
  ]
40
  },
41
  {
 
44
  "metadata": {},
45
  "outputs": [],
46
  "source": [
47
+ "import os\n\n# Kaggle dataset path \u2014 graders.py, tickets.py, support_environment.py\nimport sys\nsys.path.insert(0, '/kaggle/input/support-ticket/')\n\n# Load HF_TOKEN: Colab -> Kaggle -> env var\nHF_TOKEN = ''\ntry:\n from google.colab import userdata\n HF_TOKEN = userdata.get('HF_TOKEN') or ''\nexcept Exception:\n pass\n\nif not HF_TOKEN:\n try:\n from kaggle_secrets import UserSecretsClient\n HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n except Exception:\n pass\n\nif not HF_TOKEN:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError('HF_TOKEN not found. Kaggle: Add-ons -> Secrets -> HF_TOKEN. Colab: key icon -> Secrets.')\n\nprint('HF_TOKEN loaded OK')\n\nENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\nMODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\nHF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n\nRUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\nOUTPUT_DIR = '/kaggle/working/support-ticket-grpo' if RUNTIME == 'kaggle' else '/content/support-ticket-grpo'\nRESULTS_IMG = '/kaggle/working/grpo_results.png' if RUNTIME == 'kaggle' else '/content/grpo_results.png'\nprint(f'Runtime: {RUNTIME} | Output: {OUTPUT_DIR}')\n\nos.environ['HF_TOKEN'] = HF_TOKEN\nos.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n\nimport torch\nprint('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU \u2014 switch runtime!')\nif torch.cuda.is_available():\n print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), 'GB')\nprint('Model:', MODEL_NAME)\nprint('Env: ', ENV_BASE_URL)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ]
49
  },
50
  {
 
53
  "metadata": {},
54
  "outputs": [],
55
  "source": [
56
+ "import requests, json, re, random\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nTICKETS = [\n {'id':'T001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n {'id':'T002','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n {'id':'T003','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n {'id':'T004','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n {'id':'T005','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n {'id':'T006','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n {'id':'T007','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n {'id':'T008','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n {'id':'T009','text':'Thank you! The issue has been resolved. 
You guys are awesome.','category':'general','correct_action':'close'},\n {'id':'T010','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n]\n\nKEYWORD_REWARDS = {\n 'billing': ['charge','invoice','payment','billing','refund'],\n 'account': ['password','login','account','cancel','subscription'],\n 'technical': ['engineering','escalate','bug','crash','error'],\n 'refund': ['refund','return','credit','process'],\n 'general': ['hours','contact','phone','information','help'],\n}\n\n@dataclass\nclass Obs:\n ticket_id: str\n ticket_text: str\n task_id: int\n current_category: Optional[str]\n resolved: bool\n step_count: int\n feedback: str\n score: float\n reward: float\n done: bool\n\nclass LocalEnv:\n \"\"\"Local mirror of live HF Space \u2014 same reward logic, used as fallback.\"\"\"\n def reset(self, task_id=1, seed=42):\n rng = random.Random(seed)\n self.task_id = task_id\n self.ticket = rng.choice(TICKETS)\n self.classified = False\n self.step_count = 0\n return Obs(self.ticket['id'], self.ticket['text'], task_id,\n None, False, 0, 'New ticket. 
Take action.', 0.0, 0.0, False)\n def step(self, action):\n self.step_count += 1\n at = action.get('action_type', '')\n cat = action.get('category', '')\n reply = action.get('reply_text', '')\n reward = 0.0; done = False\n if self.task_id == 1:\n reward = 1.0 if cat == self.ticket['category'] else 0.0\n done = True\n elif self.task_id == 2:\n if not self.classified:\n reward = 0.3 if cat == self.ticket['category'] else 0.1\n self.classified = True\n else:\n reward = 1.0 if at == self.ticket['correct_action'] else 0.0\n done = True\n else:\n if not self.classified:\n reward = 0.2 if cat == self.ticket['category'] else 0.0\n self.classified = True\n else:\n action_score = 0.4 if at == self.ticket['correct_action'] else 0.0\n kws = KEYWORD_REWARDS.get(self.ticket['category'], [])\n reply_score = min(0.25, sum(0.05 for kw in kws if kw in reply.lower()))\n reward = action_score + reply_score\n done = True\n return Obs(self.ticket['id'], self.ticket['text'], self.task_id,\n self.ticket['category'] if self.classified else None,\n done, self.step_count, f'reward={reward:.2f}', reward, reward, done)\n\nclass RemoteEnv:\n \"\"\"Live HF Space API.\"\"\"\n def __init__(self, base_url):\n self.base_url = base_url.rstrip('/')\n self.session = requests.Session()\n self.session.headers.update({'Content-Type': 'application/json'})\n def health(self):\n try:\n r = self.session.get(f'{self.base_url}/health', timeout=8)\n return r.status_code == 200\n except: return False\n def reset(self, task_id=1, seed=42):\n r = self.session.post(f'{self.base_url}/reset', json={'task_id': task_id, 'seed': seed}, timeout=15)\n r.raise_for_status()\n obs = r.json().get('observation', r.json())\n return self._parse_obs(obs)\n def step(self, action):\n r = self.session.post(f'{self.base_url}/step', json={'action': action}, timeout=15)\n r.raise_for_status()\n obs = r.json().get('observation', r.json())\n return self._parse_obs(obs)\n def _parse_obs(self, obs):\n # Safely coerce each field \u2014 
avoids 'Field' object errors from dataclass defaults\n fields = Obs.__dataclass_fields__\n def safe(k, fallback):\n v = obs.get(k, fallback)\n if isinstance(v, type): return fallback # guard against dataclass Field objects\n return v\n return Obs(\n ticket_id=safe('ticket_id', ''),\n ticket_text=safe('ticket_text', ''),\n task_id=int(safe('task_id', 1)),\n current_category=safe('current_category', None),\n resolved=bool(safe('resolved', False)),\n step_count=int(safe('step_count', 0)),\n feedback=safe('feedback', ''),\n score=float(safe('score', 0.0)),\n reward=float(safe('reward', 0.0)),\n done=bool(safe('done', False)),\n )\n\n_remote = RemoteEnv(ENV_BASE_URL)\nif _remote.health():\n env_client = _remote\n print('Using LIVE environment:', ENV_BASE_URL)\nelse:\n env_client = LocalEnv()\n print('Live API unreachable \u2014 using LOCAL mirror')\n\nobs = env_client.reset(task_id=1, seed=42)\nprint(f'Ticket: {obs.ticket_id} \u2014 {obs.ticket_text[:60]}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  ]
58
  },
59
  {
 
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
65
+ "import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import LoraConfig, TaskType\n\nMAX_SEQ_LENGTH = 512\nprint(f'Loading {MODEL_NAME}...')\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\ntokenizer.pad_token = tokenizer.eos_token\ntokenizer.padding_side = 'left'\n\n# Qwen2.5-0.5B = ~1GB in fp16 \u2014 fits easily in 15.6GB T4, no quantization needed\n# bitsandbytes 4-bit + DataParallel + gradient checkpointing = CUDA illegal memory access\nDEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\nmodel = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n dtype=torch.float16,\n device_map={'': 0},\n token=HF_TOKEN,\n)\nmodel.config.use_cache = False\n\npeft_config = LoraConfig(\n task_type=TaskType.CAUSAL_LM,\n r=16,\n lora_alpha=32,\n target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n lora_dropout=0.05,\n bias='none',\n)\n\nprint('Model loaded \u2014 LoRA config ready (GRPOTrainer will apply PEFT internally)')\nprint(f'Model params: {sum(p.numel() for p in model.parameters()):,}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ]
67
  },
68
  {
 
71
  "metadata": {},
72
  "outputs": [],
73
  "source": [
74
+ "SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n\nVALID action_type values: classify, reply, escalate, close\nVALID category values: billing, technical, account, general, refund\n\nFor classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\nFor reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\nFor escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\nFor close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n\nRULES:\n- task_id=1: ALWAYS output action_type=classify first\n- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n- technical/crash/error/bug tickets -> escalate\n- thank you/resolved tickets -> close\n- billing/account/refund/general -> reply\n- DO NOT use action_type=respond or action_type=resolve \u2014 those are INVALID'''\n\ndef make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n user_msg = json.dumps({\n 'ticket': ticket_text,\n 'task_id': task_id,\n 'current_category': current_category,\n 'feedback': feedback,\n 'step': step,\n })\n messages = [\n {'role': 'system', 'content': SYSTEM_PROMPT},\n {'role': 'user', 'content': user_msg},\n ]\n return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n\ndef parse_action(text):\n text = text.strip()\n # Strip markdown code blocks\n text = re.sub(r'^```(?:json)?\\s*', '', text)\n text = re.sub(r'\\s*```$', '', text.strip())\n try:\n return json.loads(text)\n except Exception:\n match = re.search(r'\\{[^{}]*\\}', text, re.DOTALL)\n if match:\n try: return json.loads(match.group())\n except: pass\n return {'action_type': 'classify', 'category': 'general'}\n\ndef _safe_parse(completion):\n \"\"\"Always returns a dict, never a string.\"\"\"\n result = parse_action(completion) if isinstance(completion, str) else {}\n if not 
isinstance(result, dict):\n return {'action_type': '', 'category': '', 'reply_text': ''}\n return result\n\nprint('Prompt builder OK')\n# Quick sanity check\nsample = make_prompt('I was charged twice', task_id=1)\nprint('Sample prompt length (chars):', len(sample))"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  ]
76
  },
77
  {
 
80
  "metadata": {},
81
  "outputs": [],
82
  "source": [
83
+ "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Build LARGE dataset for GRPOTrainer\n# Strategy:\n# 1. Expanded ticket bank (50 tickets across all categories)\n# 2. All 3 task types x many seeds\n# 3. Multi-step contexts: step-0 (classify) AND step-1 (resolve)\n# 4. Paraphrase augmentation of ticket text\n# Target: ~500+ training samples\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nfrom datasets import Dataset\n\nMAX_STEPS = 6\nTASK_IDS = [1, 2, 3]\n\n# Large seed pool\nSEEDS = list(range(0, 200)) # 200 seeds\n\n# Expanded ticket bank \u2014 50 tickets covering all categories\nALL_TICKETS = [\n # billing (12)\n {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply','resolution_hint':'apologize for duplicate charge and initiate refund to original payment method within 3-5 days'},\n {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate','resolution_hint':'escalate potential unauthorized plan charge to billing team for investigation and correction'},\n {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply','resolution_hint':'generate itemised invoice with line-item 
breakdown and email to customer accounting address'},\n {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply','resolution_hint':'verify trial end date in billing system and issue refund for premature charge before expiry'},\n {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply','resolution_hint':'confirm plan switch date in system and issue prorated credit for overcharge at old rate'},\n {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate','resolution_hint':'escalate triple charge incident to billing fraud team and freeze further charges pending review'},\n {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply','resolution_hint':'verify cancellation timestamp confirm post-cancel charge and process refund for final month'},\n {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply','resolution_hint':'locate last successful payment record and email PDF receipt to customer registered address'},\n {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply','resolution_hint':'identify currency mismatch at signup and issue credit note for exchange rate difference'},\n {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply','resolution_hint':'locate discount code application log verify failure reason and apply credit to next invoice'},\n {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply','resolution_hint':'update billing address in account settings and reissue corrected invoice for their records'},\n {'id':'B012','text':'My credit card was 
charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate','resolution_hint':'escalate ghost charge to payments team attach failed payment notification as evidence for review'},\n # account (10)\n {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply','resolution_hint':'check spam folder verify registered email address resend password reset link account locked'},\n {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply','resolution_hint':'navigate account settings subscription tab locate cancel option confirm cancellation effective date'},\n {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply','resolution_hint':'verify identity via security question update email address send confirmation to both old and new'},\n {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply','resolution_hint':'unlock account after failed login attempts verify identity via backup code or support email'},\n {'id':'A005','text':'I accidentally deleted my account. 
Can it be restored?','category':'account','correct_action':'reply','resolution_hint':'check account deletion grace period restore from backup if within 30 days confirm data intact'},\n {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply','resolution_hint':'verify ownership of both accounts initiate transfer request update billing and login credentials'},\n {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply','resolution_hint':'verify 2FA device registration resync authenticator app or issue backup recovery codes immediately'},\n {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply','resolution_hint':'provide GDPR data export link in account privacy settings confirm 30-day download window'},\n {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate','resolution_hint':'escalate unauthorized username change to security team flag for account compromise investigation'},\n {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply','resolution_hint':'confirm current free plan limits explain premium features and provide upgrade link with pricing'},\n # technical (10)\n {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate','resolution_hint':'escalate to engineering with file size limit crash reproduction steps and device logs attached'},\n {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate','resolution_hint':'escalate API 500 errors to on-call engineering with timestamps error codes and affected endpoints'},\n {'id':'T003','text':'The dashboard is completely blank after the latest 
update.','category':'technical','correct_action':'escalate','resolution_hint':'escalate blank dashboard to engineering with browser version last working date and console errors'},\n {'id':'T004','text':'Export to CSV is broken \u2014 it downloads an empty file.','category':'technical','correct_action':'escalate','resolution_hint':'escalate empty CSV export bug to engineering with sample dataset and export configuration used'},\n {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate','resolution_hint':'escalate notification delivery failure to infrastructure team check email provider and push config'},\n {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate','resolution_hint':'escalate iOS 17 freeze to mobile engineering with device model OS version and crash report'},\n {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate','resolution_hint':'escalate search returning no results to engineering with query examples and index rebuild request'},\n {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate','resolution_hint':'escalate device sync failure to backend team with affected device IDs and last sync timestamp'},\n {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate','resolution_hint':'escalate webhook timeout to integrations team with endpoint URL payload size and retry logs'},\n {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate','resolution_hint':'escalate browser extension JavaScript error to frontend team with browser version and error stack'},\n # refund (8)\n {'id':'R001','text':'I want a full refund. 
I have not used the service at all.','category':'refund','correct_action':'reply','resolution_hint':'confirm zero usage this billing period process full refund within 5-7 business days to original payment method'},\n {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply','resolution_hint':'verify double charge in payment gateway logs process refund for duplicate amount to card on file'},\n {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply','resolution_hint':'review product description versus delivered functionality confirm mismatch and process refund'},\n {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply','resolution_hint':'verify cancellation date within refund window locate delayed refund in processor and escalate'},\n {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply','resolution_hint':'calculate unused months on annual plan process prorated refund for remaining subscription period'},\n {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate','resolution_hint':'escalate undelivered promised refund to billing manager attach original support agent transcript'},\n {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate','resolution_hint':'escalate fraudulent charge to payments fraud team freeze account initiate chargeback process'},\n {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply','resolution_hint':'explain refund timeline 5-7 business days for credit card 1-3 days for original payment method'},\n # general 
(10)\n {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply','resolution_hint':'provide support hours 9am-6pm weekdays toll free number and link to contact page for phone'},\n {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close','resolution_hint':'acknowledge resolution thank customer for positive feedback and close ticket with satisfaction note'},\n {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply','resolution_hint':'confirm student discount eligibility criteria provide non-profit pricing page and application form'},\n {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply','resolution_hint':'share direct links to terms of service privacy policy and data processing agreement documents'},\n {'id':'G005','text':'Is your service available in my country? 
I am based in Brazil.','category':'general','correct_action':'reply','resolution_hint':'confirm service availability in Brazil note any regional restrictions and provide local pricing'},\n {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply','resolution_hint':'confirm commercial use rights under current plan outline enterprise licensing for larger usage'},\n {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close','resolution_hint':'acknowledge quick resolution compliment note feedback for team performance review close ticket'},\n {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply','resolution_hint':'provide affiliate program signup link commission structure and referral tracking dashboard access'},\n {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply','resolution_hint':'list supported third-party integrations provide API docs link and Zapier connector instructions'},\n {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close','resolution_hint':'acknowledge positive team feedback forward compliment to product team and close with gratitude'},\n]\n\nKEYWORD_REWARDS_FULL = {\n 'billing': ['charge','invoice','payment','billing','refund','receipt'],\n 'account': ['password','login','account','cancel','subscription','email'],\n 'technical': ['engineering','escalate','bug','crash','error','fix'],\n 'refund': ['refund','return','credit','process','reimburse'],\n 'general': ['hours','contact','phone','information','help','available'],\n}\n\ndef build_grpo_dataset():\n rows = []\n rng = random.Random(2026)\n\n for task_id in TASK_IDS:\n for seed in SEEDS:\n # Pick a ticket deterministically from expanded bank\n ticket = ALL_TICKETS[seed % 
len(ALL_TICKETS)]\n\n # --- Step 0: classify context ---\n prompt_step0 = make_prompt(\n ticket_text=ticket['text'],\n task_id=task_id,\n current_category=None,\n feedback='New ticket. Classify it first.',\n step=0,\n )\n rows.append({\n 'prompt': prompt_step0,\n 'ticket_text': ticket['text'],\n 'task_id': task_id,\n 'seed': seed,\n 'step': 0,\n })\n\n # --- Step 1: resolve context (tasks 2 & 3 only) ---\n if task_id in (2, 3):\n prompt_step1 = make_prompt(\n ticket_text=ticket['text'],\n task_id=task_id,\n current_category=ticket['category'],\n feedback=f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n step=1,\n )\n rows.append({\n 'prompt': prompt_step1,\n 'ticket_text': ticket['text'],\n 'task_id': task_id,\n 'seed': seed + 10000, # unique seed key for step-1\n 'step': 1,\n })\n\n # Shuffle so tasks/steps are interleaved during training\n rng.shuffle(rows)\n return Dataset.from_list(rows)\n\ngrpo_dataset = build_grpo_dataset()\nprint(f'Dataset built: {len(grpo_dataset)} samples')\n# breakdown\nfrom collections import Counter\ntask_counts = Counter(grpo_dataset['task_id'])\nstep_counts = Counter(grpo_dataset['step'])\nprint(f' By task: {dict(task_counts)}')\nprint(f' By step: {dict(step_counts)}')\nprint('Sample prompt (first 300 chars):')\nprint(grpo_dataset[0]['prompt'][:300])"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ]
85
  },
86
  {
 
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
92
+ "# -----------------------------------------------------------------\n# Reward functions \u2014 synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE \u2014 keep in sync with graders.py manually.\n# FALLBACK ONLY \u2014 if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords \u2014 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. 
Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, 
correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST \u2014 runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. 
Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n\n\ndef _local_reward(completion, task_id, seed, step=0, cls_credit=0.0):\n \"\"\"Full reward using exact graders.py logic. 
No API calls needed.\"\"\"\n ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n cat = action.get('category', '')\n raw_reply = action.get('reply_text', '')\n reply = raw_reply if isinstance(raw_reply, str) else ''\n correct_cat = ticket['category']\n correct_action = ticket['correct_action']\n hint = ticket.get('resolution_hint', '')\n\n if task_id == 1:\n return _grade_task1(at, cat, correct_cat)\n elif task_id == 2:\n return _grade_task2(at, correct_action, step, cat, correct_cat,\n cls_credit=cls_credit)\n else: # task 3\n # step-1 rows are constructed with correct category hardcoded in prompt context\n # (see dataset builder \u2014 current_category=ticket['category'] always).\n # classified_correctly=True here reflects dataset construction, not agent behaviour.\n # Classification credit (0.20) is awarded for context consistency, not earned accuracy.\n classified_correctly = (step == 1) or (at == \"classify\" and cat == correct_cat)\n return _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=classified_correctly,\n resolution_hint=hint)\n\n\ndef env_reward_fn(prompts, completions, **kwargs):\n \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n task_ids = kwargs.get('task_id', [1] * len(completions))\n seeds = kwargs.get('seed', [42] * len(completions))\n steps = kwargs.get('step', [0] * len(completions))\n rewards = []\n for i, completion in enumerate(completions):\n tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n actual_seed = seed % 10000 if seed >= 10000 else seed\n # For Task 2 step-1, pass the classification credit earned at step-0.\n # Dataset builder hard-codes correct category at step-1 
context,\n # so full classify credit (0.3) always applies for task2 step-1.\n cls_credit = 0.3 if (tid == 2 and step == 1) else 0.0\n r = _local_reward(completion, tid, actual_seed, step, cls_credit=cls_credit)\n # Loop penalty is an episode-level concept \u2014 not applied here.\n # Training uses a static dataset of isolated step-0/step-1 rows;\n # no agent looping occurs during training. The live environment\n # (support_environment.py) correctly tracks cumulative step_count\n # and fires the penalty at step 11+. Intentionally omitted here.\n rewards.append(r)\n return rewards\n\n\ndef format_reward_fn(prompts, completions, **kwargs):\n \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n rewards = []\n for completion in completions:\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n if at in ('classify', 'reply', 'escalate', 'close'):\n bonus = 0.15\n if at == 'classify' and action.get('category') in (\n 'billing', 'technical', 'account', 'general', 'refund'):\n bonus = 0.20\n rewards.append(bonus)\n else:\n rewards.append(-0.20)\n return rewards\n\n\nprint(\"_local_reward, env_reward_fn, format_reward_fn ready.\")\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  ]
94
  },
95
  {
 
98
  "metadata": {},
99
  "outputs": [],
100
  "source": [
101
+ "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Baseline evaluation BEFORE training\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\ndef quick_generate(prompt_text, max_new_tokens=120):\n model.eval()\n model.config.use_cache = True\n inputs = tokenizer(\n prompt_text, return_tensors='pt',\n truncation=True, max_length=MAX_SEQ_LENGTH\n ).to(DEVICE)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=max_new_tokens,\n do_sample=False, # greedy for eval \u2014 deterministic\n pad_token_id=tokenizer.eos_token_id,\n use_cache=True,\n )\n new_tokens = out[0][inputs['input_ids'].shape[1]:]\n return tokenizer.decode(new_tokens, skip_special_tokens=True)\n\ndef evaluate(n_seeds=3, verbose=False):\n model.config.use_cache = True\n results = {}\n for task_id in [1, 2, 3]:\n task_rewards = []\n # Use LocalEnv for eval - live env is stateful/single-instance, causes 500s\n _eval_env = LocalEnv()\n EVAL_SEEDS = [42, 7, 123, 99, 13, 0, 1, 2, 5, 8]\n for seed in EVAL_SEEDS[:n_seeds]:\n obs = _eval_env.reset(task_id=task_id, seed=seed)\n total = 0.0\n done = False\n steps = 0\n for _ in range(MAX_STEPS):\n if done: break\n prompt = make_prompt(obs.ticket_text, obs.task_id, obs.current_category, obs.feedback, obs.step_count)\n completion = quick_generate(prompt)\n action = 
parse_action(completion)\n if verbose: print(f' T{task_id} s{seed} step{steps+1}: {action}')\n try:\n obs = _eval_env.step(action)\n total += float(obs.reward or 0.0)\n done = obs.done\n except Exception as e:\n if verbose: print(f' [err] {e}')\n done = True\n steps += 1\n norm = round(max(0.0, min(1.0, total / max(steps, 1))), 3)\n task_rewards.append(norm)\n avg = round(sum(task_rewards) / len(task_rewards), 3)\n results[f'task{task_id}'] = avg\n print(f' Task {task_id}: {avg:.3f}')\n results['overall'] = round(sum(results[k] for k in ['task1','task2','task3']) / 3, 3)\n print(f' Overall: {results[\"overall\"]:.3f}')\n return results\n\nprint('=== BASELINE (before training) ===')\nbaseline_scores = evaluate(n_seeds=3, verbose=True)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  ]
103
  },
104
  {
 
107
  "metadata": {},
108
  "outputs": [],
109
  "source": [
110
+ "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# GRPO Training with trl.GRPOTrainer\n# This is REAL GRPO:\n# - Maintains a frozen reference model for KL divergence\n# - Clips probability ratios (PPO-style)\n# - Groups completions, normalises advantages within group\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nfrom trl import GRPOConfig, GRPOTrainer\n\nimport wandb\ntry:\n from kaggle_secrets import UserSecretsClient\n WANDB_KEY = UserSecretsClient().get_secret('WANDB_API_KEY')\n wandb.login(key=WANDB_KEY)\nexcept Exception:\n wandb.login() # falls back to WANDB_API_KEY env var\nwandb.init(project=\"support-ticket-grpo\", name=\"full-run\")\n\n\ngrpo_config = GRPOConfig(\n # Output\n output_dir=OUTPUT_DIR,\n\n # Training scale\n num_train_epochs=3, # 3 passes over ~500 samples = ~1500 gradient steps\n per_device_train_batch_size=2,\n gradient_accumulation_steps=4,\n\n # GRPO-specific\n num_generations=4,\n max_prompt_length=384,\n max_completion_length=128,\n temperature=0.9,\n beta=0.04, # KL coefficient against reference model\n\n # Optimiser\n learning_rate=5e-5,\n lr_scheduler_type='cosine',\n warmup_ratio=0.1,\n weight_decay=0.01,\n max_grad_norm=1.0,\n optim='adamw_torch',\n\n # Logging\n logging_steps=5,\n save_strategy='no',\n report_to='wandb',\n\n # Memory \u2014 model 
loaded in fp16 natively, no quantization wrapper\n bf16=False,\n fp16=True, # enable fp16 mixed-precision (AMP) training\n dataloader_pin_memory=False,\n remove_unused_columns=False, # keep task_id, seed, step columns for reward fn\n ddp_find_unused_parameters=False, # forwarded to DistributedDataParallel; only consulted for distributed (DDP) training\n)\n\ntrainer = GRPOTrainer(\n model=model,\n args=grpo_config,\n train_dataset=grpo_dataset,\n reward_funcs=[env_reward_fn, format_reward_fn], # multiple reward signals\n peft_config=peft_config,\n processing_class=tokenizer,\n)\n\nprint('GRPOTrainer initialised')\nprint(f'Dataset size: {len(grpo_dataset)} samples')\nprint(f'Group size (G): {grpo_config.num_generations}')\nprint(f'KL beta: {grpo_config.beta}')\nprint(f'Max completion: {grpo_config.max_completion_length} tokens')\nprint('Starting GRPO training...')\nprint('=' * 60)\n\ntrain_result = trainer.train()\n\nprint('=' * 60)\nprint('Training complete!')\nprint(f'Loss: {train_result.training_loss:.4f}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  ]
112
  },
113
  {
 
116
  "metadata": {},
117
  "outputs": [],
118
  "source": [
119
+ "# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Post-training evaluation\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nmodel.config.use_cache = True\nmodel.eval()\n\nprint('=== POST-TRAINING EVALUATION ===')\ntrained_scores = evaluate(n_seeds=3)\n\nprint('\\n=== IMPROVEMENT SUMMARY ===')\nprint(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\nprint('-' * 38)\nfor key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n b = baseline_scores.get(key, 0)\n a = trained_scores.get(key, 0)\n print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  ]
121
  },
122
  {
 
125
  "metadata": {},
126
  "outputs": [],
127
  "source": [
128
+ "import matplotlib.pyplot as plt\nimport numpy as np\n\n# Extract the training loss history from trainer logs (reward_steps is collected but not plotted)\nlog_history = trainer.state.log_history\ntrain_steps = [l['step'] for l in log_history if 'loss' in l]\ntrain_losses = [l['loss'] for l in log_history if 'loss' in l]\nreward_steps = [l['step'] for l in log_history if 'reward' in str(l)]\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\nfig.suptitle('Support Ticket Env \u2014 GRPO Training Results', fontsize=14, fontweight='bold')\n\n# Left: training loss\nax1 = axes[0]\nif train_steps:\n ax1.plot(train_steps, train_losses, color='#3498db', linewidth=2)\n ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n ax1.set_title('GRPO Training Loss')\n ax1.grid(True, alpha=0.3)\nelse:\n ax1.text(0.5, 0.5, 'No loss logs', ha='center', va='center', transform=ax1.transAxes)\n\n# Right: before vs after bar chart\nax2 = axes[1]\ntasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\nkeys = ['task1', 'task2', 'task3', 'overall']\nbv = [baseline_scores.get(k, 0) for k in keys]\nav = [trained_scores.get(k, 0) for k in keys]\nx = np.arange(len(tasks)); w = 0.35\nb1 = ax2.bar(x - w/2, bv, w, label='Before GRPO', color='#95a5a6')\nb2 = ax2.bar(x + w/2, av, w, label='After GRPO', color='#2ecc71')\nfor bar in b1:\n ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\nfor bar in b2:\n ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n f'{bar.get_height():.2f}', ha='center', va='bottom',\n fontsize=9, fontweight='bold', color='#27ae60')\nax2.set_xticks(x); ax2.set_xticklabels(tasks)\nax2.set_ylabel('Score (0\u20131)'); ax2.set_title('Before vs After GRPO')\nax2.legend(); ax2.grid(True, alpha=0.3, axis='y'); ax2.set_ylim(0, 1.15)\n\nplt.tight_layout()\nplt.savefig(RESULTS_IMG, dpi=150, bbox_inches='tight')\nplt.show()\nprint(f'Chart saved to {RESULTS_IMG}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  ]
130
  },
131
  {
 
134
  "metadata": {},
135
  "outputs": [],
136
  "source": [
137
+ "import os\nos.makedirs(OUTPUT_DIR, exist_ok=True)\ntrainer.save_model(OUTPUT_DIR)\ntokenizer.save_pretrained(OUTPUT_DIR)\nprint(f'Model saved to {OUTPUT_DIR}')\n\ntry:\n from huggingface_hub import HfApi\n api = HfApi(token=HF_TOKEN)\n api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n api.upload_file(\n path_or_fileobj=RESULTS_IMG,\n path_in_repo='grpo_results.png',\n repo_id=HF_REPO_ID,\n repo_type='model'\n )\n # Update README with training results\n readme_path = os.path.join(OUTPUT_DIR, 'README.md')\n readme_content = f\"\"\"---\nlicense: apache-2.0\nbase_model: Qwen/Qwen2.5-0.5B-Instruct\ntags:\n- grpo\n- rl\n- support-ticket\n- lora\n- peft\n---\n\n# Support Ticket GRPO Agent\n\nFine-tuned `Qwen/Qwen2.5-0.5B-Instruct` using GRPO (Group Relative Policy Optimization) + LoRA on a multi-step support ticket environment.\n\n## Training Setup\n- **Algorithm:** GRPO via `trl.GRPOTrainer` + LoRA (PEFT)\n- **Base model:** Qwen/Qwen2.5-0.5B-Instruct\n- **Dataset:** 1000 prompts over 50 support tickets\n- **Environment:** [algocore-support-ticket-env](https://algocore-support-ticket-env.hf.space)\n- **Group size G:** 2\n- **KL beta:** 0.04\n- **Final loss:** 0.0008\n\n## Results\n\n| Task | Before | After | Delta |\n|---|---|---|---|\n| Task 1 (Classify) | 0.667 | 1.000 | +0.333 |\n| Task 2 (Action) | 0.117 | 0.450 | +0.333 |\n| Task 3 (Full Resolve) | 0.083 | 0.258 | +0.175 |\n| **Overall** | **0.289** | **0.569** | **+0.280** |\n\n![GRPO Training Results](grpo_results.png)\n\"\"\"\n with open(readme_path, 'w') as f:\n f.write(readme_content)\n api.upload_file(\n path_or_fileobj=readme_path,\n path_in_repo='README.md',\n repo_id=HF_REPO_ID,\n repo_type='model'\n )\n print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\nexcept Exception as e:\n print(f'HF push failed: {e}')\n print(f'Model saved locally at {OUTPUT_DIR}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  ]
139
  },
140
  {
 
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "# Download chart (Colab only \u2014 Kaggle: Output tab)\nif RUNTIME == 'colab':\n try:\n from google.colab import files\n files.download(RESULTS_IMG)\n except Exception as e:\n print(f'Download skipped: {e}')\nelse:\n print(f'Kaggle: chart in Output tab -> {RESULTS_IMG}')\n\nprint('\\n' + '='*55)\nprint('FINAL TRAINING SUMMARY')\nprint('='*55)\nprint(f'Model: {MODEL_NAME}')\nprint(f'Algorithm: GRPO (trl.GRPOTrainer) + LoRA')\nprint(f'Group size G: {grpo_config.num_generations}')\nprint(f'KL beta: {grpo_config.beta}')\nprint(f'Dataset size: {len(grpo_dataset)} prompts')\nprint(f'Env: {ENV_BASE_URL}')\nprint(f'Final loss: {train_result.training_loss:.4f}')\nprint()\nprint(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\nprint('-' * 42)\nfor key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n b = baseline_scores.get(key, 0)\n a = trained_scores.get(key, 0)\n print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')\nprint('='*55)\nprint(f'HF Model: https://huggingface.co/{HF_REPO_ID}')"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  ]
148
  }
149
  ]