Spaces:
Sleeping
Sleeping
Vighnesh commited on
Commit ·
5648ca2
1
Parent(s): 7bdf1e0
add training notebook
Browse files- train_grpo.ipynb +15 -787
train_grpo.ipynb
CHANGED
|
@@ -16,7 +16,7 @@
|
|
| 16 |
"cell_type": "markdown",
|
| 17 |
"metadata": {},
|
| 18 |
"source": [
|
| 19 |
-
"# Support Ticket Env
|
| 20 |
"**OpenEnv x Scalar Hackathon**\n",
|
| 21 |
"\n",
|
| 22 |
"Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using **real GRPO** (`trl.GRPOTrainer`) + LoRA (PEFT)\n",
|
|
@@ -26,7 +26,7 @@
|
|
| 26 |
"- **Algorithm:** GRPO via `trl.GRPOTrainer` (proper clipped ratio + KL vs reference model)\n",
|
| 27 |
"- **Environment:** https://algocore-support-ticket-env.hf.space\n",
|
| 28 |
"- **Runtime:** ~30-45 min on Kaggle P100/T4 (or Colab)\n",
|
| 29 |
-
"- **No Unsloth**
|
| 30 |
]
|
| 31 |
},
|
| 32 |
{
|
|
@@ -35,10 +35,7 @@
|
|
| 35 |
"metadata": {},
|
| 36 |
"outputs": [],
|
| 37 |
"source": [
|
| 38 |
-
"# Install dependencies\n
|
| 39 |
-
"!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n",
|
| 40 |
-
"!pip install -q bitsandbytes requests matplotlib\n",
|
| 41 |
-
"print('Installation complete')"
|
| 42 |
]
|
| 43 |
},
|
| 44 |
{
|
|
@@ -47,52 +44,7 @@
|
|
| 47 |
"metadata": {},
|
| 48 |
"outputs": [],
|
| 49 |
"source": [
|
| 50 |
-
"import os\n
|
| 51 |
-
"\n",
|
| 52 |
-
"# Load HF_TOKEN: Colab -> Kaggle -> env var\n",
|
| 53 |
-
"HF_TOKEN = ''\n",
|
| 54 |
-
"try:\n",
|
| 55 |
-
" from google.colab import userdata\n",
|
| 56 |
-
" HF_TOKEN = userdata.get('HF_TOKEN') or ''\n",
|
| 57 |
-
"except Exception:\n",
|
| 58 |
-
" pass\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"if not HF_TOKEN:\n",
|
| 61 |
-
" try:\n",
|
| 62 |
-
" from kaggle_secrets import UserSecretsClient\n",
|
| 63 |
-
" HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n",
|
| 64 |
-
" except Exception:\n",
|
| 65 |
-
" pass\n",
|
| 66 |
-
"\n",
|
| 67 |
-
"if not HF_TOKEN:\n",
|
| 68 |
-
" HF_TOKEN = os.environ.get('HF_TOKEN', '')\n",
|
| 69 |
-
"\n",
|
| 70 |
-
"if not HF_TOKEN:\n",
|
| 71 |
-
" raise ValueError('HF_TOKEN not found. Kaggle: Add-ons -> Secrets -> HF_TOKEN. Colab: key icon -> Secrets.')\n",
|
| 72 |
-
"\n",
|
| 73 |
-
"print('HF_TOKEN loaded OK')\n",
|
| 74 |
-
"\n",
|
| 75 |
-
"ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
|
| 76 |
-
"MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
|
| 77 |
-
"# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n",
|
| 78 |
-
"# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n",
|
| 79 |
-
"# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\n",
|
| 80 |
-
"HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
|
| 81 |
-
"\n",
|
| 82 |
-
"RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
|
| 83 |
-
"OUTPUT_DIR = '/kaggle/working/support-ticket-grpo' if RUNTIME == 'kaggle' else '/content/support-ticket-grpo'\n",
|
| 84 |
-
"RESULTS_IMG = '/kaggle/working/grpo_results.png' if RUNTIME == 'kaggle' else '/content/grpo_results.png'\n",
|
| 85 |
-
"print(f'Runtime: {RUNTIME} | Output: {OUTPUT_DIR}')\n",
|
| 86 |
-
"\n",
|
| 87 |
-
"os.environ['HF_TOKEN'] = HF_TOKEN\n",
|
| 88 |
-
"os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n",
|
| 89 |
-
"\n",
|
| 90 |
-
"import torch\n",
|
| 91 |
-
"print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU — switch runtime!')\n",
|
| 92 |
-
"if torch.cuda.is_available():\n",
|
| 93 |
-
" print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), 'GB')\n",
|
| 94 |
-
"print('Model:', MODEL_NAME)\n",
|
| 95 |
-
"print('Env: ', ENV_BASE_URL)"
|
| 96 |
]
|
| 97 |
},
|
| 98 |
{
|
|
@@ -101,135 +53,7 @@
|
|
| 101 |
"metadata": {},
|
| 102 |
"outputs": [],
|
| 103 |
"source": [
|
| 104 |
-
"import requests, json, re, random\n",
|
| 105 |
-
"from dataclasses import dataclass\n",
|
| 106 |
-
"from typing import Optional\n",
|
| 107 |
-
"\n",
|
| 108 |
-
"TICKETS = [\n",
|
| 109 |
-
" {'id':'T001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
|
| 110 |
-
" {'id':'T002','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
|
| 111 |
-
" {'id':'T003','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
|
| 112 |
-
" {'id':'T004','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
|
| 113 |
-
" {'id':'T005','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
|
| 114 |
-
" {'id':'T006','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
|
| 115 |
-
" {'id':'T007','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
|
| 116 |
-
" {'id':'T008','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
|
| 117 |
-
" {'id':'T009','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
|
| 118 |
-
" {'id':'T010','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
|
| 119 |
-
"]\n",
|
| 120 |
-
"\n",
|
| 121 |
-
"KEYWORD_REWARDS = {\n",
|
| 122 |
-
" 'billing': ['charge','invoice','payment','billing','refund'],\n",
|
| 123 |
-
" 'account': ['password','login','account','cancel','subscription'],\n",
|
| 124 |
-
" 'technical': ['engineering','escalate','bug','crash','error'],\n",
|
| 125 |
-
" 'refund': ['refund','return','credit','process'],\n",
|
| 126 |
-
" 'general': ['hours','contact','phone','information','help'],\n",
|
| 127 |
-
"}\n",
|
| 128 |
-
"\n",
|
| 129 |
-
"@dataclass\n",
|
| 130 |
-
"class Obs:\n",
|
| 131 |
-
" ticket_id: str\n",
|
| 132 |
-
" ticket_text: str\n",
|
| 133 |
-
" task_id: int\n",
|
| 134 |
-
" current_category: Optional[str]\n",
|
| 135 |
-
" resolved: bool\n",
|
| 136 |
-
" step_count: int\n",
|
| 137 |
-
" feedback: str\n",
|
| 138 |
-
" score: float\n",
|
| 139 |
-
" reward: float\n",
|
| 140 |
-
" done: bool\n",
|
| 141 |
-
"\n",
|
| 142 |
-
"class LocalEnv:\n",
|
| 143 |
-
" \"\"\"Local mirror of live HF Space — same reward logic, used as fallback.\"\"\"\n",
|
| 144 |
-
" def reset(self, task_id=1, seed=42):\n",
|
| 145 |
-
" rng = random.Random(seed)\n",
|
| 146 |
-
" self.task_id = task_id\n",
|
| 147 |
-
" self.ticket = rng.choice(TICKETS)\n",
|
| 148 |
-
" self.classified = False\n",
|
| 149 |
-
" self.step_count = 0\n",
|
| 150 |
-
" return Obs(self.ticket['id'], self.ticket['text'], task_id,\n",
|
| 151 |
-
" None, False, 0, 'New ticket. Take action.', 0.0, 0.0, False)\n",
|
| 152 |
-
" def step(self, action):\n",
|
| 153 |
-
" self.step_count += 1\n",
|
| 154 |
-
" at = action.get('action_type', '')\n",
|
| 155 |
-
" cat = action.get('category', '')\n",
|
| 156 |
-
" reply = action.get('reply_text', '')\n",
|
| 157 |
-
" reward = 0.0; done = False\n",
|
| 158 |
-
" if self.task_id == 1:\n",
|
| 159 |
-
" reward = 1.0 if cat == self.ticket['category'] else 0.0\n",
|
| 160 |
-
" done = True\n",
|
| 161 |
-
" elif self.task_id == 2:\n",
|
| 162 |
-
" if not self.classified:\n",
|
| 163 |
-
" reward = 0.3 if cat == self.ticket['category'] else 0.1\n",
|
| 164 |
-
" self.classified = True\n",
|
| 165 |
-
" else:\n",
|
| 166 |
-
" reward = 1.0 if at == self.ticket['correct_action'] else 0.0\n",
|
| 167 |
-
" done = True\n",
|
| 168 |
-
" else:\n",
|
| 169 |
-
" if not self.classified:\n",
|
| 170 |
-
" reward = 0.2 if cat == self.ticket['category'] else 0.0\n",
|
| 171 |
-
" self.classified = True\n",
|
| 172 |
-
" else:\n",
|
| 173 |
-
" action_score = 0.4 if at == self.ticket['correct_action'] else 0.0\n",
|
| 174 |
-
" kws = KEYWORD_REWARDS.get(self.ticket['category'], [])\n",
|
| 175 |
-
" reply_score = min(0.25, sum(0.05 for kw in kws if kw in reply.lower()))\n",
|
| 176 |
-
" reward = action_score + reply_score\n",
|
| 177 |
-
" done = True\n",
|
| 178 |
-
" return Obs(self.ticket['id'], self.ticket['text'], self.task_id,\n",
|
| 179 |
-
" self.ticket['category'] if self.classified else None,\n",
|
| 180 |
-
" done, self.step_count, f'reward={reward:.2f}', reward, reward, done)\n",
|
| 181 |
-
"\n",
|
| 182 |
-
"class RemoteEnv:\n",
|
| 183 |
-
" \"\"\"Live HF Space API.\"\"\"\n",
|
| 184 |
-
" def __init__(self, base_url):\n",
|
| 185 |
-
" self.base_url = base_url.rstrip('/')\n",
|
| 186 |
-
" self.session = requests.Session()\n",
|
| 187 |
-
" self.session.headers.update({'Content-Type': 'application/json'})\n",
|
| 188 |
-
" def health(self):\n",
|
| 189 |
-
" try:\n",
|
| 190 |
-
" r = self.session.get(f'{self.base_url}/health', timeout=8)\n",
|
| 191 |
-
" return r.status_code == 200\n",
|
| 192 |
-
" except: return False\n",
|
| 193 |
-
" def reset(self, task_id=1, seed=42):\n",
|
| 194 |
-
" r = self.session.post(f'{self.base_url}/reset', json={'task_id': task_id, 'seed': seed}, timeout=15)\n",
|
| 195 |
-
" r.raise_for_status()\n",
|
| 196 |
-
" obs = r.json().get('observation', r.json())\n",
|
| 197 |
-
" return self._parse_obs(obs)\n",
|
| 198 |
-
" def step(self, action):\n",
|
| 199 |
-
" r = self.session.post(f'{self.base_url}/step', json={'action': action}, timeout=15)\n",
|
| 200 |
-
" r.raise_for_status()\n",
|
| 201 |
-
" obs = r.json().get('observation', r.json())\n",
|
| 202 |
-
" return self._parse_obs(obs)\n",
|
| 203 |
-
" def _parse_obs(self, obs):\n",
|
| 204 |
-
" # Safely coerce each field — avoids 'Field' object errors from dataclass defaults\n",
|
| 205 |
-
" fields = Obs.__dataclass_fields__\n",
|
| 206 |
-
" def safe(k, fallback):\n",
|
| 207 |
-
" v = obs.get(k, fallback)\n",
|
| 208 |
-
" if isinstance(v, type): return fallback # guard against dataclass Field objects\n",
|
| 209 |
-
" return v\n",
|
| 210 |
-
" return Obs(\n",
|
| 211 |
-
" ticket_id=safe('ticket_id', ''),\n",
|
| 212 |
-
" ticket_text=safe('ticket_text', ''),\n",
|
| 213 |
-
" task_id=int(safe('task_id', 1)),\n",
|
| 214 |
-
" current_category=safe('current_category', None),\n",
|
| 215 |
-
" resolved=bool(safe('resolved', False)),\n",
|
| 216 |
-
" step_count=int(safe('step_count', 0)),\n",
|
| 217 |
-
" feedback=safe('feedback', ''),\n",
|
| 218 |
-
" score=float(safe('score', 0.0)),\n",
|
| 219 |
-
" reward=float(safe('reward', 0.0)),\n",
|
| 220 |
-
" done=bool(safe('done', False)),\n",
|
| 221 |
-
" )\n",
|
| 222 |
-
"\n",
|
| 223 |
-
"_remote = RemoteEnv(ENV_BASE_URL)\n",
|
| 224 |
-
"if _remote.health():\n",
|
| 225 |
-
" env_client = _remote\n",
|
| 226 |
-
" print('Using LIVE environment:', ENV_BASE_URL)\n",
|
| 227 |
-
"else:\n",
|
| 228 |
-
" env_client = LocalEnv()\n",
|
| 229 |
-
" print('Live API unreachable — using LOCAL mirror')\n",
|
| 230 |
-
"\n",
|
| 231 |
-
"obs = env_client.reset(task_id=1, seed=42)\n",
|
| 232 |
-
"print(f'Ticket: {obs.ticket_id} — {obs.ticket_text[:60]}')"
|
| 233 |
]
|
| 234 |
},
|
| 235 |
{
|
|
@@ -238,39 +62,7 @@
|
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
| 241 |
-
"import torch\n
|
| 242 |
-
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 243 |
-
"from peft import LoraConfig, TaskType\n",
|
| 244 |
-
"\n",
|
| 245 |
-
"MAX_SEQ_LENGTH = 512\n",
|
| 246 |
-
"print(f'Loading {MODEL_NAME}...')\n",
|
| 247 |
-
"\n",
|
| 248 |
-
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\n",
|
| 249 |
-
"tokenizer.pad_token = tokenizer.eos_token\n",
|
| 250 |
-
"tokenizer.padding_side = 'left'\n",
|
| 251 |
-
"\n",
|
| 252 |
-
"# Qwen2.5-0.5B = ~1GB in fp16 — fits easily in 15.6GB T4, no quantization needed\n",
|
| 253 |
-
"# bitsandbytes 4-bit + DataParallel + gradient checkpointing = CUDA illegal memory access\n",
|
| 254 |
-
"DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
|
| 255 |
-
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 256 |
-
" MODEL_NAME,\n",
|
| 257 |
-
" dtype=torch.float16,\n",
|
| 258 |
-
" device_map={'': 0},\n",
|
| 259 |
-
" token=HF_TOKEN,\n",
|
| 260 |
-
")\n",
|
| 261 |
-
"model.config.use_cache = False\n",
|
| 262 |
-
"\n",
|
| 263 |
-
"peft_config = LoraConfig(\n",
|
| 264 |
-
" task_type=TaskType.CAUSAL_LM,\n",
|
| 265 |
-
" r=16,\n",
|
| 266 |
-
" lora_alpha=32,\n",
|
| 267 |
-
" target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n",
|
| 268 |
-
" lora_dropout=0.05,\n",
|
| 269 |
-
" bias='none',\n",
|
| 270 |
-
")\n",
|
| 271 |
-
"\n",
|
| 272 |
-
"print('Model loaded — LoRA config ready (GRPOTrainer will apply PEFT internally)')\n",
|
| 273 |
-
"print(f'Model params: {sum(p.numel() for p in model.parameters()):,}')"
|
| 274 |
]
|
| 275 |
},
|
| 276 |
{
|
|
@@ -279,64 +71,7 @@
|
|
| 279 |
"metadata": {},
|
| 280 |
"outputs": [],
|
| 281 |
"source": [
|
| 282 |
-
"SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n",
|
| 283 |
-
"\n",
|
| 284 |
-
"VALID action_type values: classify, reply, escalate, close\n",
|
| 285 |
-
"VALID category values: billing, technical, account, general, refund\n",
|
| 286 |
-
"\n",
|
| 287 |
-
"For classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\n",
|
| 288 |
-
"For reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\n",
|
| 289 |
-
"For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n",
|
| 290 |
-
"For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n",
|
| 291 |
-
"\n",
|
| 292 |
-
"RULES:\n",
|
| 293 |
-
"- task_id=1: ALWAYS output action_type=classify first\n",
|
| 294 |
-
"- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n",
|
| 295 |
-
"- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n",
|
| 296 |
-
"- technical/crash/error/bug tickets -> escalate\n",
|
| 297 |
-
"- thank you/resolved tickets -> close\n",
|
| 298 |
-
"- billing/account/refund/general -> reply\n",
|
| 299 |
-
"- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n",
|
| 300 |
-
"\n",
|
| 301 |
-
"def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n",
|
| 302 |
-
" user_msg = json.dumps({\n",
|
| 303 |
-
" 'ticket': ticket_text,\n",
|
| 304 |
-
" 'task_id': task_id,\n",
|
| 305 |
-
" 'current_category': current_category,\n",
|
| 306 |
-
" 'feedback': feedback,\n",
|
| 307 |
-
" 'step': step,\n",
|
| 308 |
-
" })\n",
|
| 309 |
-
" messages = [\n",
|
| 310 |
-
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
| 311 |
-
" {'role': 'user', 'content': user_msg},\n",
|
| 312 |
-
" ]\n",
|
| 313 |
-
" return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 314 |
-
"\n",
|
| 315 |
-
"def parse_action(text):\n",
|
| 316 |
-
" text = text.strip()\n",
|
| 317 |
-
" # Strip markdown code blocks\n",
|
| 318 |
-
" text = re.sub(r'^```(?:json)?\\s*', '', text)\n",
|
| 319 |
-
" text = re.sub(r'\\s*```$', '', text.strip())\n",
|
| 320 |
-
" try:\n",
|
| 321 |
-
" return json.loads(text)\n",
|
| 322 |
-
" except Exception:\n",
|
| 323 |
-
" match = re.search(r'\\{[^{}]*\\}', text, re.DOTALL)\n",
|
| 324 |
-
" if match:\n",
|
| 325 |
-
" try: return json.loads(match.group())\n",
|
| 326 |
-
" except: pass\n",
|
| 327 |
-
" return {'action_type': 'classify', 'category': 'general'}\n",
|
| 328 |
-
"\n",
|
| 329 |
-
"def _safe_parse(completion):\n",
|
| 330 |
-
" \"\"\"Always returns a dict, never a string.\"\"\"\n",
|
| 331 |
-
" result = parse_action(completion) if isinstance(completion, str) else {}\n",
|
| 332 |
-
" if not isinstance(result, dict):\n",
|
| 333 |
-
" return {'action_type': '', 'category': '', 'reply_text': ''}\n",
|
| 334 |
-
" return result\n",
|
| 335 |
-
"\n",
|
| 336 |
-
"print('Prompt builder OK')\n",
|
| 337 |
-
"# Quick sanity check\n",
|
| 338 |
-
"sample = make_prompt('I was charged twice', task_id=1)\n",
|
| 339 |
-
"print('Sample prompt length (chars):', len(sample))"
|
| 340 |
]
|
| 341 |
},
|
| 342 |
{
|
|
@@ -345,146 +80,7 @@
|
|
| 345 |
"metadata": {},
|
| 346 |
"outputs": [],
|
| 347 |
"source": [
|
| 348 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 349 |
-
"# Build LARGE dataset for GRPOTrainer\n",
|
| 350 |
-
"# Strategy:\n",
|
| 351 |
-
"# 1. Expanded ticket bank (50 tickets across all categories)\n",
|
| 352 |
-
"# 2. All 3 task types x many seeds\n",
|
| 353 |
-
"# 3. Multi-step contexts: step-0 (classify) AND step-1 (resolve)\n",
|
| 354 |
-
"# 4. Paraphrase augmentation of ticket text\n",
|
| 355 |
-
"# Target: ~500+ training samples\n",
|
| 356 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 357 |
-
"from datasets import Dataset\n",
|
| 358 |
-
"\n",
|
| 359 |
-
"MAX_STEPS = 6\n",
|
| 360 |
-
"TASK_IDS = [1, 2, 3]\n",
|
| 361 |
-
"\n",
|
| 362 |
-
"# Large seed pool\n",
|
| 363 |
-
"SEEDS = list(range(0, 200)) # 200 seeds\n",
|
| 364 |
-
"\n",
|
| 365 |
-
"# Expanded ticket bank — 50 tickets covering all categories\n",
|
| 366 |
-
"ALL_TICKETS = [\n",
|
| 367 |
-
" # billing (12)\n",
|
| 368 |
-
" {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
|
| 369 |
-
" {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
|
| 370 |
-
" {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
|
| 371 |
-
" {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n",
|
| 372 |
-
" {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n",
|
| 373 |
-
" {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n",
|
| 374 |
-
" {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n",
|
| 375 |
-
" {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n",
|
| 376 |
-
" {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n",
|
| 377 |
-
" {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n",
|
| 378 |
-
" {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n",
|
| 379 |
-
" {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n",
|
| 380 |
-
" # account (10)\n",
|
| 381 |
-
" {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
|
| 382 |
-
" {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
|
| 383 |
-
" {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n",
|
| 384 |
-
" {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n",
|
| 385 |
-
" {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n",
|
| 386 |
-
" {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n",
|
| 387 |
-
" {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n",
|
| 388 |
-
" {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n",
|
| 389 |
-
" {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n",
|
| 390 |
-
" {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n",
|
| 391 |
-
" # technical (10)\n",
|
| 392 |
-
" {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
|
| 393 |
-
" {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
|
| 394 |
-
" {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n",
|
| 395 |
-
" {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n",
|
| 396 |
-
" {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n",
|
| 397 |
-
" {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n",
|
| 398 |
-
" {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n",
|
| 399 |
-
" {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n",
|
| 400 |
-
" {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n",
|
| 401 |
-
" {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n",
|
| 402 |
-
" # refund (8)\n",
|
| 403 |
-
" {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
|
| 404 |
-
" {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n",
|
| 405 |
-
" {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n",
|
| 406 |
-
" {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n",
|
| 407 |
-
" {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n",
|
| 408 |
-
" {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n",
|
| 409 |
-
" {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n",
|
| 410 |
-
" {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n",
|
| 411 |
-
" # general (10)\n",
|
| 412 |
-
" {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
|
| 413 |
-
" {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
|
| 414 |
-
" {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n",
|
| 415 |
-
" {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n",
|
| 416 |
-
" {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n",
|
| 417 |
-
" {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n",
|
| 418 |
-
" {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n",
|
| 419 |
-
" {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n",
|
| 420 |
-
" {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n",
|
| 421 |
-
" {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n",
|
| 422 |
-
"]\n",
|
| 423 |
-
"\n",
|
| 424 |
-
"KEYWORD_REWARDS_FULL = {\n",
|
| 425 |
-
" 'billing': ['charge','invoice','payment','billing','refund','receipt'],\n",
|
| 426 |
-
" 'account': ['password','login','account','cancel','subscription','email'],\n",
|
| 427 |
-
" 'technical': ['engineering','escalate','bug','crash','error','fix'],\n",
|
| 428 |
-
" 'refund': ['refund','return','credit','process','reimburse'],\n",
|
| 429 |
-
" 'general': ['hours','contact','phone','information','help','available'],\n",
|
| 430 |
-
"}\n",
|
| 431 |
-
"\n",
|
| 432 |
-
"def build_grpo_dataset():\n",
|
| 433 |
-
" rows = []\n",
|
| 434 |
-
" rng = random.Random(2026)\n",
|
| 435 |
-
"\n",
|
| 436 |
-
" for task_id in TASK_IDS:\n",
|
| 437 |
-
" for seed in SEEDS:\n",
|
| 438 |
-
" # Pick a ticket deterministically from expanded bank\n",
|
| 439 |
-
" ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
|
| 440 |
-
"\n",
|
| 441 |
-
" # --- Step 0: classify context ---\n",
|
| 442 |
-
" prompt_step0 = make_prompt(\n",
|
| 443 |
-
" ticket_text=ticket['text'],\n",
|
| 444 |
-
" task_id=task_id,\n",
|
| 445 |
-
" current_category=None,\n",
|
| 446 |
-
" feedback='New ticket. Classify it first.',\n",
|
| 447 |
-
" step=0,\n",
|
| 448 |
-
" )\n",
|
| 449 |
-
" rows.append({\n",
|
| 450 |
-
" 'prompt': prompt_step0,\n",
|
| 451 |
-
" 'ticket_text': ticket['text'],\n",
|
| 452 |
-
" 'task_id': task_id,\n",
|
| 453 |
-
" 'seed': seed,\n",
|
| 454 |
-
" 'step': 0,\n",
|
| 455 |
-
" })\n",
|
| 456 |
-
"\n",
|
| 457 |
-
" # --- Step 1: resolve context (tasks 2 & 3 only) ---\n",
|
| 458 |
-
" if task_id in (2, 3):\n",
|
| 459 |
-
" prompt_step1 = make_prompt(\n",
|
| 460 |
-
" ticket_text=ticket['text'],\n",
|
| 461 |
-
" task_id=task_id,\n",
|
| 462 |
-
" current_category=ticket['category'],\n",
|
| 463 |
-
" feedback=f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n",
|
| 464 |
-
" step=1,\n",
|
| 465 |
-
" )\n",
|
| 466 |
-
" rows.append({\n",
|
| 467 |
-
" 'prompt': prompt_step1,\n",
|
| 468 |
-
" 'ticket_text': ticket['text'],\n",
|
| 469 |
-
" 'task_id': task_id,\n",
|
| 470 |
-
" 'seed': seed + 10000, # unique seed key for step-1\n",
|
| 471 |
-
" 'step': 1,\n",
|
| 472 |
-
" })\n",
|
| 473 |
-
"\n",
|
| 474 |
-
" # Shuffle so tasks/steps are interleaved during training\n",
|
| 475 |
-
" rng.shuffle(rows)\n",
|
| 476 |
-
" return Dataset.from_list(rows)\n",
|
| 477 |
-
"\n",
|
| 478 |
-
"grpo_dataset = build_grpo_dataset()\n",
|
| 479 |
-
"print(f'Dataset built: {len(grpo_dataset)} samples')\n",
|
| 480 |
-
"# breakdown\n",
|
| 481 |
-
"from collections import Counter\n",
|
| 482 |
-
"task_counts = Counter(grpo_dataset['task_id'])\n",
|
| 483 |
-
"step_counts = Counter(grpo_dataset['step'])\n",
|
| 484 |
-
"print(f' By task: {dict(task_counts)}')\n",
|
| 485 |
-
"print(f' By step: {dict(step_counts)}')\n",
|
| 486 |
-
"print('Sample prompt (first 300 chars):')\n",
|
| 487 |
-
"print(grpo_dataset[0]['prompt'][:300])"
|
| 488 |
]
|
| 489 |
},
|
| 490 |
{
|
|
@@ -493,142 +89,7 @@
|
|
| 493 |
"metadata": {},
|
| 494 |
"outputs": [],
|
| 495 |
"source": [
|
| 496 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 497 |
-
"# Reward functions — exact mirror of graders.py\n",
|
| 498 |
-
"# grade_task1 / grade_task2 / grade_task3 / loop_penalty\n",
|
| 499 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 500 |
-
"\n",
|
| 501 |
-
"# Partial-credit action pairs (from graders.py)\n",
|
| 502 |
-
"_PARTIAL_CREDIT_PAIRS = {frozenset({'reply', 'escalate'})}\n",
|
| 503 |
-
"\n",
|
| 504 |
-
"# Keyword lists (from graders.py)\n",
|
| 505 |
-
"_KEYWORD_REWARDS = {\n",
|
| 506 |
-
" 'billing': ['refund', 'charge', 'invoice', 'payment', 'billing'],\n",
|
| 507 |
-
" 'account': ['password', 'login', 'account', 'cancel', 'subscription'],\n",
|
| 508 |
-
" 'technical': ['engineering', 'escalate', 'bug', 'crash', 'error', 'fix'],\n",
|
| 509 |
-
" 'refund': ['refund', 'return', 'credit', 'process'],\n",
|
| 510 |
-
" 'general': ['hours', 'contact', 'phone', 'information', 'help'],\n",
|
| 511 |
-
"}\n",
|
| 512 |
-
"\n",
|
| 513 |
-
"def _reply_quality(reply_text, category):\n",
|
| 514 |
-
" \"\"\"Exact copy of graders._reply_quality: 0.0–0.5 keyword score.\"\"\"\n",
|
| 515 |
-
" if not reply_text: return 0.0\n",
|
| 516 |
-
" hits = sum(1 for kw in _KEYWORD_REWARDS.get(category, []) if kw in reply_text.lower())\n",
|
| 517 |
-
" return min(0.5, hits * 0.1)\n",
|
| 518 |
-
"\n",
|
| 519 |
-
"def _grade_task1(at, cat, correct_cat):\n",
|
| 520 |
-
" \"\"\"Exact copy of graders.grade_task1.\"\"\"\n",
|
| 521 |
-
" return 1.0 if (at == 'classify' and cat == correct_cat) else 0.0\n",
|
| 522 |
-
"\n",
|
| 523 |
-
"def _grade_task2(at, correct_action, step, cat, correct_cat):\n",
|
| 524 |
-
" \"\"\"Exact copy of graders.grade_task2 + classify step.\"\"\"\n",
|
| 525 |
-
" if step == 0:\n",
|
| 526 |
-
" # classify step: partial credit for correct category\n",
|
| 527 |
-
" if at == 'classify' and cat == correct_cat: return 0.3\n",
|
| 528 |
-
" if at == 'classify': return 0.1\n",
|
| 529 |
-
" return 0.0\n",
|
| 530 |
-
" # action step\n",
|
| 531 |
-
" if at == correct_action: return 1.0\n",
|
| 532 |
-
" if frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS: return 0.5\n",
|
| 533 |
-
" if at == 'close': return 0.0\n",
|
| 534 |
-
" return 0.0\n",
|
| 535 |
-
"\n",
|
| 536 |
-
"def _grade_task3(at, cat, correct_cat, correct_action, reply, step, steps_taken=2, max_steps=5):\n",
|
| 537 |
-
" \"\"\"Exact copy of graders.grade_task3.\"\"\"\n",
|
| 538 |
-
" if step == 0:\n",
|
| 539 |
-
" # classification step only\n",
|
| 540 |
-
" return 0.20 if (at == 'classify' and cat == correct_cat) else 0.0\n",
|
| 541 |
-
" # resolution step: 0.40 action + up to 0.50 reply + 0.15 efficiency\n",
|
| 542 |
-
" score = 0.0\n",
|
| 543 |
-
" classified_correctly = True # step-1 means step-0 already happened\n",
|
| 544 |
-
" score += 0.20 # classification credit carried from step 0\n",
|
| 545 |
-
" action_correct = (at == correct_action)\n",
|
| 546 |
-
" action_partial = (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n",
|
| 547 |
-
" if action_correct: score += 0.40\n",
|
| 548 |
-
" elif action_partial: score += 0.20\n",
|
| 549 |
-
" score += _reply_quality(reply, cat) # 0.0–0.5\n",
|
| 550 |
-
" # efficiency bonus (assume 2 steps taken for step-1 samples)\n",
|
| 551 |
-
" resolved = action_correct or action_partial\n",
|
| 552 |
-
" if resolved and steps_taken <= max_steps:\n",
|
| 553 |
-
" efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n",
|
| 554 |
-
" score += 0.15 * efficiency\n",
|
| 555 |
-
" return round(min(1.0, score), 4)\n",
|
| 556 |
-
"\n",
|
| 557 |
-
"def _loop_penalty(step_count, max_steps=10):\n",
|
| 558 |
-
" \"\"\"Exact copy of graders.loop_penalty.\"\"\"\n",
|
| 559 |
-
" return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n",
|
| 560 |
-
"\n",
|
| 561 |
-
"def _local_reward(completion, task_id, seed, step=0):\n",
|
| 562 |
-
" \"\"\"Full reward using exact graders.py logic. No API calls needed.\"\"\"\n",
|
| 563 |
-
" ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
|
| 564 |
-
" action = _safe_parse(completion)\n",
|
| 565 |
-
" if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
|
| 566 |
-
" at = action.get('action_type', '')\n",
|
| 567 |
-
" cat = action.get('category', '')\n",
|
| 568 |
-
" reply = action.get('reply_text', '') or ''\n",
|
| 569 |
-
" correct_cat = ticket['category']\n",
|
| 570 |
-
" correct_action = ticket['correct_action']\n",
|
| 571 |
-
"\n",
|
| 572 |
-
" if task_id == 1:\n",
|
| 573 |
-
" return _grade_task1(at, cat, correct_cat)\n",
|
| 574 |
-
" elif task_id == 2:\n",
|
| 575 |
-
" return _grade_task2(at, correct_action, step, cat, correct_cat)\n",
|
| 576 |
-
" else: # task 3\n",
|
| 577 |
-
" return _grade_task3(at, cat, correct_cat, correct_action, reply, step)\n",
|
| 578 |
-
"\n",
|
| 579 |
-
"def env_reward_fn(prompts, completions, **kwargs):\n",
|
| 580 |
-
" \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n",
|
| 581 |
-
" task_ids = kwargs.get('task_id', [1] * len(completions))\n",
|
| 582 |
-
" seeds = kwargs.get('seed', [42] * len(completions))\n",
|
| 583 |
-
" steps = kwargs.get('step', [0] * len(completions))\n",
|
| 584 |
-
" rewards = []\n",
|
| 585 |
-
" for i, completion in enumerate(completions):\n",
|
| 586 |
-
" tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n",
|
| 587 |
-
" seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n",
|
| 588 |
-
" step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n",
|
| 589 |
-
" actual_seed = seed % 10000 if seed >= 10000 else seed\n",
|
| 590 |
-
" r = _local_reward(completion, tid, actual_seed, step)\n",
|
| 591 |
-
" # apply loop penalty if step is high\n",
|
| 592 |
-
" r += _loop_penalty(step)\n",
|
| 593 |
-
" rewards.append(r)\n",
|
| 594 |
-
" return rewards\n",
|
| 595 |
-
"\n",
|
| 596 |
-
"def format_reward_fn(prompts, completions, **kwargs):\n",
|
| 597 |
-
" \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n",
|
| 598 |
-
" rewards = []\n",
|
| 599 |
-
" for completion in completions:\n",
|
| 600 |
-
" action = _safe_parse(completion)\n",
|
| 601 |
-
" if not isinstance(action, dict): action = {'action_type': '', 'category': '', 'reply_text': ''}\n",
|
| 602 |
-
" at = action.get('action_type', '')\n",
|
| 603 |
-
" if at in ('classify', 'reply', 'escalate', 'close'):\n",
|
| 604 |
-
" bonus = 0.15\n",
|
| 605 |
-
" if at == 'classify' and action.get('category') in ('billing','technical','account','general','refund'):\n",
|
| 606 |
-
" bonus = 0.20\n",
|
| 607 |
-
" rewards.append(bonus)\n",
|
| 608 |
-
" else:\n",
|
| 609 |
-
" rewards.append(-0.20)\n",
|
| 610 |
-
" return rewards\n",
|
| 611 |
-
"\n",
|
| 612 |
-
"# Print ticket map\n",
|
| 613 |
-
"print('Reward functions synced to graders.py')\n",
|
| 614 |
-
"print('Ticket map (seed % len):')\n",
|
| 615 |
-
"for _i in range(6):\n",
|
| 616 |
-
" _tt = ALL_TICKETS[_i]\n",
|
| 617 |
-
" print(f' [{_i}] {_tt[\"id\"]} cat={_tt[\"category\"]} action={_tt[\"correct_action\"]}')\n",
|
| 618 |
-
"\n",
|
| 619 |
-
"# Sanity: seed=0->B001(billing,reply), seed=22->T001(technical,escalate)\n",
|
| 620 |
-
"_t0 = ALL_TICKETS[0] # B001 billing reply\n",
|
| 621 |
-
"_t22 = ALL_TICKETS[22] # T001 technical escalate\n",
|
| 622 |
-
"r1 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 1, 0, 0)\n",
|
| 623 |
-
"r2 = _local_reward(json.dumps({'action_type':'classify','category':_t0['category']}), 2, 0, 0)\n",
|
| 624 |
-
"r3 = _local_reward(json.dumps({'action_type':'escalate'}), 2, 0, 1)\n",
|
| 625 |
-
"r4 = _local_reward(json.dumps({'action_type':_t22['correct_action'],'reply_text':'escalating this crash bug error to engineering team for a fix'}), 3, 22, 1)\n",
|
| 626 |
-
"r5 = format_reward_fn(prompts=['x'], completions=[json.dumps({'action_type':'respond'})])[0]\n",
|
| 627 |
-
"print(f'task1 correct classify: {r1} (expect 1.0)')\n",
|
| 628 |
-
"print(f'task2 step0 correct classify: {r2} (expect 0.3)')\n",
|
| 629 |
-
"print(f'task2 step1 partial escalate: {r3} (expect 0.5)')\n",
|
| 630 |
-
"print(f'task3 step1 correct+keywords: {r4} (expect 0.87+)')\n",
|
| 631 |
-
"print(f'bad format penalty: {r5} (expect -0.2)')\n"
|
| 632 |
]
|
| 633 |
},
|
| 634 |
{
|
|
@@ -637,65 +98,7 @@
|
|
| 637 |
"metadata": {},
|
| 638 |
"outputs": [],
|
| 639 |
"source": [
|
| 640 |
-
"#
|
| 641 |
-
"# Baseline evaluation BEFORE training\n",
|
| 642 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 643 |
-
"def quick_generate(prompt_text, max_new_tokens=120):\n",
|
| 644 |
-
" model.eval()\n",
|
| 645 |
-
" model.config.use_cache = True\n",
|
| 646 |
-
" inputs = tokenizer(\n",
|
| 647 |
-
" prompt_text, return_tensors='pt',\n",
|
| 648 |
-
" truncation=True, max_length=MAX_SEQ_LENGTH\n",
|
| 649 |
-
" ).to(DEVICE)\n",
|
| 650 |
-
" with torch.no_grad():\n",
|
| 651 |
-
" out = model.generate(\n",
|
| 652 |
-
" **inputs,\n",
|
| 653 |
-
" max_new_tokens=max_new_tokens,\n",
|
| 654 |
-
" do_sample=False, # greedy for eval — deterministic\n",
|
| 655 |
-
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 656 |
-
" use_cache=True,\n",
|
| 657 |
-
" )\n",
|
| 658 |
-
" new_tokens = out[0][inputs['input_ids'].shape[1]:]\n",
|
| 659 |
-
" return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
|
| 660 |
-
"\n",
|
| 661 |
-
"def evaluate(n_seeds=3, verbose=False):\n",
|
| 662 |
-
" model.config.use_cache = True\n",
|
| 663 |
-
" results = {}\n",
|
| 664 |
-
" for task_id in [1, 2, 3]:\n",
|
| 665 |
-
" task_rewards = []\n",
|
| 666 |
-
" # Use LocalEnv for eval - live env is stateful/single-instance, causes 500s\n",
|
| 667 |
-
" _eval_env = LocalEnv()\n",
|
| 668 |
-
" EVAL_SEEDS = [42, 7, 123, 99, 13, 0, 1, 2, 5, 8]\n",
|
| 669 |
-
" for seed in EVAL_SEEDS[:n_seeds]:\n",
|
| 670 |
-
" obs = _eval_env.reset(task_id=task_id, seed=seed)\n",
|
| 671 |
-
" total = 0.0\n",
|
| 672 |
-
" done = False\n",
|
| 673 |
-
" steps = 0\n",
|
| 674 |
-
" for _ in range(MAX_STEPS):\n",
|
| 675 |
-
" if done: break\n",
|
| 676 |
-
" prompt = make_prompt(obs.ticket_text, obs.task_id, obs.current_category, obs.feedback, obs.step_count)\n",
|
| 677 |
-
" completion = quick_generate(prompt)\n",
|
| 678 |
-
" action = parse_action(completion)\n",
|
| 679 |
-
" if verbose: print(f' T{task_id} s{seed} step{steps+1}: {action}')\n",
|
| 680 |
-
" try:\n",
|
| 681 |
-
" obs = _eval_env.step(action)\n",
|
| 682 |
-
" total += float(obs.reward or 0.0)\n",
|
| 683 |
-
" done = obs.done\n",
|
| 684 |
-
" except Exception as e:\n",
|
| 685 |
-
" if verbose: print(f' [err] {e}')\n",
|
| 686 |
-
" done = True\n",
|
| 687 |
-
" steps += 1\n",
|
| 688 |
-
" norm = round(max(0.0, min(1.0, total / max(steps, 1))), 3)\n",
|
| 689 |
-
" task_rewards.append(norm)\n",
|
| 690 |
-
" avg = round(sum(task_rewards) / len(task_rewards), 3)\n",
|
| 691 |
-
" results[f'task{task_id}'] = avg\n",
|
| 692 |
-
" print(f' Task {task_id}: {avg:.3f}')\n",
|
| 693 |
-
" results['overall'] = round(sum(results[k] for k in ['task1','task2','task3']) / 3, 3)\n",
|
| 694 |
-
" print(f' Overall: {results[\"overall\"]:.3f}')\n",
|
| 695 |
-
" return results\n",
|
| 696 |
-
"\n",
|
| 697 |
-
"print('=== BASELINE (before training) ===')\n",
|
| 698 |
-
"baseline_scores = evaluate(n_seeds=3, verbose=True)"
|
| 699 |
]
|
| 700 |
},
|
| 701 |
{
|
|
@@ -704,74 +107,7 @@
|
|
| 704 |
"metadata": {},
|
| 705 |
"outputs": [],
|
| 706 |
"source": [
|
| 707 |
-
"#
|
| 708 |
-
"# GRPO Training with trl.GRPOTrainer\n",
|
| 709 |
-
"# This is REAL GRPO:\n",
|
| 710 |
-
"# - Maintains a frozen reference model for KL divergence\n",
|
| 711 |
-
"# - Clips probability ratios (PPO-style)\n",
|
| 712 |
-
"# - Groups completions, normalises advantages within group\n",
|
| 713 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 714 |
-
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 715 |
-
"\n",
|
| 716 |
-
"grpo_config = GRPOConfig(\n",
|
| 717 |
-
" # Output\n",
|
| 718 |
-
" output_dir=OUTPUT_DIR,\n",
|
| 719 |
-
"\n",
|
| 720 |
-
" # Training scale\n",
|
| 721 |
-
" num_train_epochs=3, # 3 passes over ~500 samples = ~1500 gradient steps\n",
|
| 722 |
-
" per_device_train_batch_size=2,\n",
|
| 723 |
-
" gradient_accumulation_steps=4,\n",
|
| 724 |
-
"\n",
|
| 725 |
-
" # GRPO-specific\n",
|
| 726 |
-
" num_generations=4, # group size G — completions sampled per prompt\n",
|
| 727 |
-
" max_prompt_length=384,\n",
|
| 728 |
-
" max_completion_length=128,\n",
|
| 729 |
-
" temperature=0.9,\n",
|
| 730 |
-
" beta=0.04, # KL coefficient against reference model\n",
|
| 731 |
-
"\n",
|
| 732 |
-
" # Optimiser\n",
|
| 733 |
-
" learning_rate=5e-5,\n",
|
| 734 |
-
" lr_scheduler_type='cosine',\n",
|
| 735 |
-
" warmup_ratio=0.1,\n",
|
| 736 |
-
" weight_decay=0.01,\n",
|
| 737 |
-
" max_grad_norm=1.0,\n",
|
| 738 |
-
" optim='adamw_torch',\n",
|
| 739 |
-
"\n",
|
| 740 |
-
" # Logging\n",
|
| 741 |
-
" logging_steps=5,\n",
|
| 742 |
-
" save_strategy='no',\n",
|
| 743 |
-
" report_to='none',\n",
|
| 744 |
-
"\n",
|
| 745 |
-
" # Memory — model loaded in fp16 natively, no quantization wrapper\n",
|
| 746 |
-
" bf16=False,\n",
|
| 747 |
-
" fp16=True, # keeps optimizer in fp16 to match model dtype\n",
|
| 748 |
-
" dataloader_pin_memory=False,\n",
|
| 749 |
-
" remove_unused_columns=False, # keep task_id, seed, step columns for reward fn\n",
|
| 750 |
-
" ddp_find_unused_parameters=False, # disable DataParallel — single GPU only\n",
|
| 751 |
-
")\n",
|
| 752 |
-
"\n",
|
| 753 |
-
"trainer = GRPOTrainer(\n",
|
| 754 |
-
" model=model,\n",
|
| 755 |
-
" args=grpo_config,\n",
|
| 756 |
-
" train_dataset=grpo_dataset,\n",
|
| 757 |
-
" reward_funcs=[env_reward_fn, format_reward_fn], # multiple reward signals\n",
|
| 758 |
-
" peft_config=peft_config,\n",
|
| 759 |
-
" processing_class=tokenizer,\n",
|
| 760 |
-
")\n",
|
| 761 |
-
"\n",
|
| 762 |
-
"print('GRPOTrainer initialised')\n",
|
| 763 |
-
"print(f'Dataset size: {len(grpo_dataset)} samples')\n",
|
| 764 |
-
"print(f'Group size (G): {grpo_config.num_generations}')\n",
|
| 765 |
-
"print(f'KL beta: {grpo_config.beta}')\n",
|
| 766 |
-
"print(f'Max completion: {grpo_config.max_completion_length} tokens')\n",
|
| 767 |
-
"print('Starting GRPO training...')\n",
|
| 768 |
-
"print('=' * 60)\n",
|
| 769 |
-
"\n",
|
| 770 |
-
"train_result = trainer.train()\n",
|
| 771 |
-
"\n",
|
| 772 |
-
"print('=' * 60)\n",
|
| 773 |
-
"print('Training complete!')\n",
|
| 774 |
-
"print(f'Loss: {train_result.training_loss:.4f}')"
|
| 775 |
]
|
| 776 |
},
|
| 777 |
{
|
|
@@ -780,22 +116,7 @@
|
|
| 780 |
"metadata": {},
|
| 781 |
"outputs": [],
|
| 782 |
"source": [
|
| 783 |
-
"#
|
| 784 |
-
"# Post-training evaluation\n",
|
| 785 |
-
"# ─────────────────────────────────────────────────────────────────\n",
|
| 786 |
-
"model.config.use_cache = True\n",
|
| 787 |
-
"model.eval()\n",
|
| 788 |
-
"\n",
|
| 789 |
-
"print('=== POST-TRAINING EVALUATION ===')\n",
|
| 790 |
-
"trained_scores = evaluate(n_seeds=3)\n",
|
| 791 |
-
"\n",
|
| 792 |
-
"print('\\n=== IMPROVEMENT SUMMARY ===')\n",
|
| 793 |
-
"print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
|
| 794 |
-
"print('-' * 38)\n",
|
| 795 |
-
"for key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n",
|
| 796 |
-
" b = baseline_scores.get(key, 0)\n",
|
| 797 |
-
" a = trained_scores.get(key, 0)\n",
|
| 798 |
-
" print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')"
|
| 799 |
]
|
| 800 |
},
|
| 801 |
{
|
|
@@ -804,52 +125,7 @@
|
|
| 804 |
"metadata": {},
|
| 805 |
"outputs": [],
|
| 806 |
"source": [
|
| 807 |
-
"import matplotlib.pyplot as plt\n
|
| 808 |
-
"import numpy as np\n",
|
| 809 |
-
"\n",
|
| 810 |
-
"# Extract training reward history from trainer logs\n",
|
| 811 |
-
"log_history = trainer.state.log_history\n",
|
| 812 |
-
"train_steps = [l['step'] for l in log_history if 'loss' in l]\n",
|
| 813 |
-
"train_losses = [l['loss'] for l in log_history if 'loss' in l]\n",
|
| 814 |
-
"reward_steps = [l['step'] for l in log_history if 'reward' in str(l)]\n",
|
| 815 |
-
"\n",
|
| 816 |
-
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 817 |
-
"fig.suptitle('Support Ticket Env — GRPO Training Results', fontsize=14, fontweight='bold')\n",
|
| 818 |
-
"\n",
|
| 819 |
-
"# Left: training loss\n",
|
| 820 |
-
"ax1 = axes[0]\n",
|
| 821 |
-
"if train_steps:\n",
|
| 822 |
-
" ax1.plot(train_steps, train_losses, color='#3498db', linewidth=2)\n",
|
| 823 |
-
" ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
|
| 824 |
-
" ax1.set_title('GRPO Training Loss')\n",
|
| 825 |
-
" ax1.grid(True, alpha=0.3)\n",
|
| 826 |
-
"else:\n",
|
| 827 |
-
" ax1.text(0.5, 0.5, 'No loss logs', ha='center', va='center', transform=ax1.transAxes)\n",
|
| 828 |
-
"\n",
|
| 829 |
-
"# Right: before vs after bar chart\n",
|
| 830 |
-
"ax2 = axes[1]\n",
|
| 831 |
-
"tasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\n",
|
| 832 |
-
"keys = ['task1', 'task2', 'task3', 'overall']\n",
|
| 833 |
-
"bv = [baseline_scores.get(k, 0) for k in keys]\n",
|
| 834 |
-
"av = [trained_scores.get(k, 0) for k in keys]\n",
|
| 835 |
-
"x = np.arange(len(tasks)); w = 0.35\n",
|
| 836 |
-
"b1 = ax2.bar(x - w/2, bv, w, label='Before GRPO', color='#95a5a6')\n",
|
| 837 |
-
"b2 = ax2.bar(x + w/2, av, w, label='After GRPO', color='#2ecc71')\n",
|
| 838 |
-
"for bar in b1:\n",
|
| 839 |
-
" ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n",
|
| 840 |
-
" f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\n",
|
| 841 |
-
"for bar in b2:\n",
|
| 842 |
-
" ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n",
|
| 843 |
-
" f'{bar.get_height():.2f}', ha='center', va='bottom',\n",
|
| 844 |
-
" fontsize=9, fontweight='bold', color='#27ae60')\n",
|
| 845 |
-
"ax2.set_xticks(x); ax2.set_xticklabels(tasks)\n",
|
| 846 |
-
"ax2.set_ylabel('Score (0–1)'); ax2.set_title('Before vs After GRPO')\n",
|
| 847 |
-
"ax2.legend(); ax2.grid(True, alpha=0.3, axis='y'); ax2.set_ylim(0, 1.15)\n",
|
| 848 |
-
"\n",
|
| 849 |
-
"plt.tight_layout()\n",
|
| 850 |
-
"plt.savefig(RESULTS_IMG, dpi=150, bbox_inches='tight')\n",
|
| 851 |
-
"plt.show()\n",
|
| 852 |
-
"print(f'Chart saved to {RESULTS_IMG}')"
|
| 853 |
]
|
| 854 |
},
|
| 855 |
{
|
|
@@ -858,27 +134,7 @@
|
|
| 858 |
"metadata": {},
|
| 859 |
"outputs": [],
|
| 860 |
"source": [
|
| 861 |
-
"import os\n",
|
| 862 |
-
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
|
| 863 |
-
"trainer.save_model(OUTPUT_DIR)\n",
|
| 864 |
-
"tokenizer.save_pretrained(OUTPUT_DIR)\n",
|
| 865 |
-
"print(f'Model saved to {OUTPUT_DIR}')\n",
|
| 866 |
-
"\n",
|
| 867 |
-
"try:\n",
|
| 868 |
-
" from huggingface_hub import HfApi\n",
|
| 869 |
-
" api = HfApi(token=HF_TOKEN)\n",
|
| 870 |
-
" api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
|
| 871 |
-
" api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n",
|
| 872 |
-
" api.upload_file(\n",
|
| 873 |
-
" path_or_fileobj=RESULTS_IMG,\n",
|
| 874 |
-
" path_in_repo='grpo_results.png',\n",
|
| 875 |
-
" repo_id=HF_REPO_ID,\n",
|
| 876 |
-
" repo_type='model'\n",
|
| 877 |
-
" )\n",
|
| 878 |
-
" print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
|
| 879 |
-
"except Exception as e:\n",
|
| 880 |
-
" print(f'HF push failed: {e}')\n",
|
| 881 |
-
" print(f'Model saved locally at {OUTPUT_DIR}')"
|
| 882 |
]
|
| 883 |
},
|
| 884 |
{
|
|
@@ -887,35 +143,7 @@
|
|
| 887 |
"metadata": {},
|
| 888 |
"outputs": [],
|
| 889 |
"source": [
|
| 890 |
-
"# Download chart (Colab only
|
| 891 |
-
"if RUNTIME == 'colab':\n",
|
| 892 |
-
" try:\n",
|
| 893 |
-
" from google.colab import files\n",
|
| 894 |
-
" files.download(RESULTS_IMG)\n",
|
| 895 |
-
" except Exception as e:\n",
|
| 896 |
-
" print(f'Download skipped: {e}')\n",
|
| 897 |
-
"else:\n",
|
| 898 |
-
" print(f'Kaggle: chart in Output tab -> {RESULTS_IMG}')\n",
|
| 899 |
-
"\n",
|
| 900 |
-
"print('\\n' + '='*55)\n",
|
| 901 |
-
"print('FINAL TRAINING SUMMARY')\n",
|
| 902 |
-
"print('='*55)\n",
|
| 903 |
-
"print(f'Model: {MODEL_NAME}')\n",
|
| 904 |
-
"print(f'Algorithm: GRPO (trl.GRPOTrainer) + LoRA')\n",
|
| 905 |
-
"print(f'Group size G: {grpo_config.num_generations}')\n",
|
| 906 |
-
"print(f'KL beta: {grpo_config.beta}')\n",
|
| 907 |
-
"print(f'Dataset size: {len(grpo_dataset)} prompts')\n",
|
| 908 |
-
"print(f'Env: {ENV_BASE_URL}')\n",
|
| 909 |
-
"print(f'Final loss: {train_result.training_loss:.4f}')\n",
|
| 910 |
-
"print()\n",
|
| 911 |
-
"print(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\n",
|
| 912 |
-
"print('-' * 42)\n",
|
| 913 |
-
"for key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n",
|
| 914 |
-
" b = baseline_scores.get(key, 0)\n",
|
| 915 |
-
" a = trained_scores.get(key, 0)\n",
|
| 916 |
-
" print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')\n",
|
| 917 |
-
"print('='*55)\n",
|
| 918 |
-
"print(f'HF Model: https://huggingface.co/{HF_REPO_ID}')"
|
| 919 |
]
|
| 920 |
}
|
| 921 |
]
|
|
|
|
| 16 |
"cell_type": "markdown",
|
| 17 |
"metadata": {},
|
| 18 |
"source": [
|
| 19 |
+
"# Support Ticket Env \u2014 GRPO Fine-Tuning\n",
|
| 20 |
"**OpenEnv x Scalar Hackathon**\n",
|
| 21 |
"\n",
|
| 22 |
"Fine-tunes `Qwen/Qwen2.5-0.5B-Instruct` using **real GRPO** (`trl.GRPOTrainer`) + LoRA (PEFT)\n",
|
|
|
|
| 26 |
"- **Algorithm:** GRPO via `trl.GRPOTrainer` (proper clipped ratio + KL vs reference model)\n",
|
| 27 |
"- **Environment:** https://algocore-support-ticket-env.hf.space\n",
|
| 28 |
"- **Runtime:** ~30-45 min on Kaggle P100/T4 (or Colab)\n",
|
| 29 |
+
"- **No Unsloth** \u2014 standard HuggingFace transformers + PEFT"
|
| 30 |
]
|
| 31 |
},
|
| 32 |
{
|
|
|
|
| 35 |
"metadata": {},
|
| 36 |
"outputs": [],
|
| 37 |
"source": [
|
| 38 |
+
"# Install dependencies\n!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n!pip install -q bitsandbytes requests matplotlib wandb\nprint('Installation complete')"
|
|
|
|
|
|
|
|
|
|
| 39 |
]
|
| 40 |
},
|
| 41 |
{
|
|
|
|
| 44 |
"metadata": {},
|
| 45 |
"outputs": [],
|
| 46 |
"source": [
|
| 47 |
+
"import os\n\n# Kaggle dataset path \u2014 graders.py, tickets.py, support_environment.py\nimport sys\nsys.path.insert(0, '/kaggle/input/support-ticket/')\n\n# Load HF_TOKEN: Colab -> Kaggle -> env var\nHF_TOKEN = ''\ntry:\n from google.colab import userdata\n HF_TOKEN = userdata.get('HF_TOKEN') or ''\nexcept Exception:\n pass\n\nif not HF_TOKEN:\n try:\n from kaggle_secrets import UserSecretsClient\n HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n except Exception:\n pass\n\nif not HF_TOKEN:\n HF_TOKEN = os.environ.get('HF_TOKEN', '')\n\nif not HF_TOKEN:\n raise ValueError('HF_TOKEN not found. Kaggle: Add-ons -> Secrets -> HF_TOKEN. Colab: key icon -> Secrets.')\n\nprint('HF_TOKEN loaded OK')\n\nENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\nMODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\nHF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n\nRUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\nOUTPUT_DIR = '/kaggle/working/support-ticket-grpo' if RUNTIME == 'kaggle' else '/content/support-ticket-grpo'\nRESULTS_IMG = '/kaggle/working/grpo_results.png' if RUNTIME == 'kaggle' else '/content/grpo_results.png'\nprint(f'Runtime: {RUNTIME} | Output: {OUTPUT_DIR}')\n\nos.environ['HF_TOKEN'] = HF_TOKEN\nos.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n\nimport torch\nprint('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU \u2014 switch runtime!')\nif torch.cuda.is_available():\n print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), 'GB')\nprint('Model:', MODEL_NAME)\nprint('Env: ', ENV_BASE_URL)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
|
|
|
| 53 |
"metadata": {},
|
| 54 |
"outputs": [],
|
| 55 |
"source": [
|
| 56 |
+
"import requests, json, re, random\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nTICKETS = [\n {'id':'T001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n {'id':'T002','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n {'id':'T003','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n {'id':'T004','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n {'id':'T005','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n {'id':'T006','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n {'id':'T007','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n {'id':'T008','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n {'id':'T009','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n {'id':'T010','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n]\n\nKEYWORD_REWARDS = {\n 'billing': ['charge','invoice','payment','billing','refund'],\n 'account': ['password','login','account','cancel','subscription'],\n 'technical': ['engineering','escalate','bug','crash','error'],\n 'refund': ['refund','return','credit','process'],\n 'general': ['hours','contact','phone','information','help'],\n}\n\n@dataclass\nclass Obs:\n ticket_id: str\n ticket_text: str\n task_id: int\n current_category: Optional[str]\n resolved: bool\n step_count: int\n feedback: str\n score: float\n reward: float\n done: bool\n\nclass LocalEnv:\n \"\"\"Local mirror of live HF Space \u2014 same reward logic, used as fallback.\"\"\"\n def reset(self, task_id=1, seed=42):\n rng = random.Random(seed)\n self.task_id = task_id\n self.ticket = rng.choice(TICKETS)\n self.classified = False\n self.step_count = 0\n return Obs(self.ticket['id'], self.ticket['text'], task_id,\n None, False, 0, 'New ticket. Take action.', 0.0, 0.0, False)\n def step(self, action):\n self.step_count += 1\n at = action.get('action_type', '')\n cat = action.get('category', '')\n reply = action.get('reply_text', '')\n reward = 0.0; done = False\n if self.task_id == 1:\n reward = 1.0 if cat == self.ticket['category'] else 0.0\n done = True\n elif self.task_id == 2:\n if not self.classified:\n reward = 0.3 if cat == self.ticket['category'] else 0.1\n self.classified = True\n else:\n reward = 1.0 if at == self.ticket['correct_action'] else 0.0\n done = True\n else:\n if not self.classified:\n reward = 0.2 if cat == self.ticket['category'] else 0.0\n self.classified = True\n else:\n action_score = 0.4 if at == self.ticket['correct_action'] else 0.0\n kws = KEYWORD_REWARDS.get(self.ticket['category'], [])\n reply_score = min(0.25, sum(0.05 for kw in kws if kw in reply.lower()))\n reward = action_score + reply_score\n done = True\n return Obs(self.ticket['id'], self.ticket['text'], self.task_id,\n self.ticket['category'] if self.classified else None,\n done, self.step_count, f'reward={reward:.2f}', reward, reward, done)\n\nclass RemoteEnv:\n \"\"\"Live HF Space API.\"\"\"\n def __init__(self, base_url):\n self.base_url = base_url.rstrip('/')\n self.session = requests.Session()\n self.session.headers.update({'Content-Type': 'application/json'})\n def health(self):\n try:\n r = self.session.get(f'{self.base_url}/health', timeout=8)\n return r.status_code == 200\n except: return False\n def reset(self, task_id=1, seed=42):\n r = self.session.post(f'{self.base_url}/reset', json={'task_id': task_id, 'seed': seed}, timeout=15)\n r.raise_for_status()\n obs = r.json().get('observation', r.json())\n return self._parse_obs(obs)\n def step(self, action):\n r = self.session.post(f'{self.base_url}/step', json={'action': action}, timeout=15)\n r.raise_for_status()\n obs = r.json().get('observation', r.json())\n return self._parse_obs(obs)\n def _parse_obs(self, obs):\n # Safely coerce each field \u2014 avoids 'Field' object errors from dataclass defaults\n fields = Obs.__dataclass_fields__\n def safe(k, fallback):\n v = obs.get(k, fallback)\n if isinstance(v, type): return fallback # guard against dataclass Field objects\n return v\n return Obs(\n ticket_id=safe('ticket_id', ''),\n ticket_text=safe('ticket_text', ''),\n task_id=int(safe('task_id', 1)),\n current_category=safe('current_category', None),\n resolved=bool(safe('resolved', False)),\n step_count=int(safe('step_count', 0)),\n feedback=safe('feedback', ''),\n score=float(safe('score', 0.0)),\n reward=float(safe('reward', 0.0)),\n done=bool(safe('done', False)),\n )\n\n_remote = RemoteEnv(ENV_BASE_URL)\nif _remote.health():\n env_client = _remote\n print('Using LIVE environment:', ENV_BASE_URL)\nelse:\n env_client = LocalEnv()\n print('Live API unreachable \u2014 using LOCAL mirror')\n\nobs = env_client.reset(task_id=1, seed=42)\nprint(f'Ticket: {obs.ticket_id} \u2014 {obs.ticket_text[:60]}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
]
|
| 58 |
},
|
| 59 |
{
|
|
|
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
| 65 |
+
"import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom peft import LoraConfig, TaskType\n\nMAX_SEQ_LENGTH = 512\nprint(f'Loading {MODEL_NAME}...')\n\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\ntokenizer.pad_token = tokenizer.eos_token\ntokenizer.padding_side = 'left'\n\n# Qwen2.5-0.5B = ~1GB in fp16 \u2014 fits easily in 15.6GB T4, no quantization needed\n# bitsandbytes 4-bit + DataParallel + gradient checkpointing = CUDA illegal memory access\nDEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\nmodel = AutoModelForCausalLM.from_pretrained(\n MODEL_NAME,\n dtype=torch.float16,\n device_map={'': 0},\n token=HF_TOKEN,\n)\nmodel.config.use_cache = False\n\npeft_config = LoraConfig(\n task_type=TaskType.CAUSAL_LM,\n r=16,\n lora_alpha=32,\n target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],\n lora_dropout=0.05,\n bias='none',\n)\n\nprint('Model loaded \u2014 LoRA config ready (GRPOTrainer will apply PEFT internally)')\nprint(f'Model params: {sum(p.numel() for p in model.parameters()):,}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
]
|
| 67 |
},
|
| 68 |
{
|
|
|
|
| 71 |
"metadata": {},
|
| 72 |
"outputs": [],
|
| 73 |
"source": [
|
| 74 |
+
"SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n\nVALID action_type values: classify, reply, escalate, close\nVALID category values: billing, technical, account, general, refund\n\nFor classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\nFor reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\nFor escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\nFor close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n\nRULES:\n- task_id=1: ALWAYS output action_type=classify first\n- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n- technical/crash/error/bug tickets -> escalate\n- thank you/resolved tickets -> close\n- billing/account/refund/general -> reply\n- DO NOT use action_type=respond or action_type=resolve \u2014 those are INVALID'''\n\ndef make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n user_msg = json.dumps({\n 'ticket': ticket_text,\n 'task_id': task_id,\n 'current_category': current_category,\n 'feedback': feedback,\n 'step': step,\n })\n messages = [\n {'role': 'system', 'content': SYSTEM_PROMPT},\n {'role': 'user', 'content': user_msg},\n ]\n return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n\ndef parse_action(text):\n text = text.strip()\n # Strip markdown code blocks\n text = re.sub(r'^```(?:json)?\\s*', '', text)\n text = re.sub(r'\\s*```$', '', text.strip())\n try:\n return json.loads(text)\n except Exception:\n match = re.search(r'\\{[^{}]*\\}', text, re.DOTALL)\n if match:\n try: return json.loads(match.group())\n except: pass\n return {'action_type': 'classify', 'category': 'general'}\n\ndef _safe_parse(completion):\n \"\"\"Always returns a dict, never a string.\"\"\"\n result = parse_action(completion) if isinstance(completion, str) else {}\n if not isinstance(result, dict):\n return {'action_type': '', 'category': '', 'reply_text': ''}\n return result\n\nprint('Prompt builder OK')\n# Quick sanity check\nsample = make_prompt('I was charged twice', task_id=1)\nprint('Sample prompt length (chars):', len(sample))"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
]
|
| 76 |
},
|
| 77 |
{
|
|
|
|
| 80 |
"metadata": {},
|
| 81 |
"outputs": [],
|
| 82 |
"source": [
|
| 83 |
+
"# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Build LARGE dataset for GRPOTrainer\n# Strategy:\n# 1. Expanded ticket bank (50 tickets across all categories)\n# 2. All 3 task types x many seeds\n# 3. Multi-step contexts: step-0 (classify) AND step-1 (resolve)\n# 4. Paraphrase augmentation of ticket text\n# Target: ~500+ training samples\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nfrom datasets import Dataset\n\nMAX_STEPS = 6\nTASK_IDS = [1, 2, 3]\n\n# Large seed pool\nSEEDS = list(range(0, 200)) # 200 seeds\n\n# Expanded ticket bank \u2014 50 tickets covering all categories\nALL_TICKETS = [\n # billing (12)\n {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply','resolution_hint':'apologize for duplicate charge and initiate refund to original payment method within 3-5 days'},\n {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate','resolution_hint':'escalate potential unauthorized plan charge to billing team for investigation and correction'},\n {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply','resolution_hint':'generate itemised invoice with line-item breakdown and email to customer accounting address'},\n {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply','resolution_hint':'verify trial end date in billing system and issue refund for premature charge before expiry'},\n {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply','resolution_hint':'confirm plan switch date in system and issue prorated credit for overcharge at old rate'},\n {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate','resolution_hint':'escalate triple charge incident to billing fraud team and freeze further charges pending review'},\n {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply','resolution_hint':'verify cancellation timestamp confirm post-cancel charge and process refund for final month'},\n {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply','resolution_hint':'locate last successful payment record and email PDF receipt to customer registered address'},\n {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply','resolution_hint':'identify currency mismatch at signup and issue credit note for exchange rate difference'},\n {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply','resolution_hint':'locate discount code application log verify failure reason and apply credit to next invoice'},\n {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply','resolution_hint':'update billing address in account settings and reissue corrected invoice for their records'},\n {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate','resolution_hint':'escalate ghost charge to payments team attach failed payment notification as evidence for review'},\n # account (10)\n {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply','resolution_hint':'check spam folder verify registered email address resend password reset link account locked'},\n {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply','resolution_hint':'navigate account settings subscription tab locate cancel option confirm cancellation effective date'},\n {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply','resolution_hint':'verify identity via security question update email address send confirmation to both old and new'},\n {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply','resolution_hint':'unlock account after failed login attempts verify identity via backup code or support email'},\n {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply','resolution_hint':'check account deletion grace period restore from backup if within 30 days confirm data intact'},\n {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply','resolution_hint':'verify ownership of both accounts initiate transfer request update billing and login credentials'},\n {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply','resolution_hint':'verify 2FA device registration resync authenticator app or issue backup recovery codes immediately'},\n {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply','resolution_hint':'provide GDPR data export link in account privacy settings confirm 30-day download window'},\n {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate','resolution_hint':'escalate unauthorized username change to security team flag for account compromise investigation'},\n {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply','resolution_hint':'confirm current free plan limits explain premium features and provide upgrade link with pricing'},\n # technical (10)\n {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate','resolution_hint':'escalate to engineering with file size limit crash reproduction steps and device logs attached'},\n {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate','resolution_hint':'escalate API 500 errors to on-call engineering with timestamps error codes and affected endpoints'},\n {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate','resolution_hint':'escalate blank dashboard to engineering with browser version last working date and console errors'},\n {'id':'T004','text':'Export to CSV is broken \u2014 it downloads an empty file.','category':'technical','correct_action':'escalate','resolution_hint':'escalate empty CSV export bug to engineering with sample dataset and export configuration used'},\n {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate','resolution_hint':'escalate notification delivery failure to infrastructure team check email provider and push config'},\n {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate','resolution_hint':'escalate iOS 17 freeze to mobile engineering with device model OS version and crash report'},\n {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate','resolution_hint':'escalate search returning no results to engineering with query examples and index rebuild request'},\n {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate','resolution_hint':'escalate device sync failure to backend team with affected device IDs and last sync timestamp'},\n {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate','resolution_hint':'escalate webhook timeout to integrations team with endpoint URL payload size and retry logs'},\n {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate','resolution_hint':'escalate browser extension JavaScript error to frontend team with browser version and error stack'},\n # refund (8)\n {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply','resolution_hint':'confirm zero usage this billing period process full refund within 5-7 business days to original payment method'},\n {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply','resolution_hint':'verify double charge in payment gateway logs process refund for duplicate amount to card on file'},\n {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply','resolution_hint':'review product description versus delivered functionality confirm mismatch and process refund'},\n {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply','resolution_hint':'verify cancellation date within refund window locate delayed refund in processor and escalate'},\n {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply','resolution_hint':'calculate unused months on annual plan process prorated refund for remaining subscription period'},\n {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate','resolution_hint':'escalate undelivered promised refund to billing manager attach original support agent transcript'},\n {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate','resolution_hint':'escalate fraudulent charge to payments fraud team freeze account initiate chargeback process'},\n {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply','resolution_hint':'explain refund timeline 5-7 business days for credit card 1-3 days for original payment method'},\n # general (10)\n {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply','resolution_hint':'provide support hours 9am-6pm weekdays toll free number and link to contact page for phone'},\n {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close','resolution_hint':'acknowledge resolution thank customer for positive feedback and close ticket with satisfaction note'},\n {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply','resolution_hint':'confirm student discount eligibility criteria provide non-profit pricing page and application form'},\n {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply','resolution_hint':'share direct links to terms of service privacy policy and data processing agreement documents'},\n {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply','resolution_hint':'confirm service availability in Brazil note any regional restrictions and provide local pricing'},\n {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply','resolution_hint':'confirm commercial use rights under current plan outline enterprise licensing for larger usage'},\n {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close','resolution_hint':'acknowledge quick resolution compliment note feedback for team performance review close ticket'},\n {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply','resolution_hint':'provide affiliate program signup link commission structure and referral tracking dashboard access'},\n {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply','resolution_hint':'list supported third-party integrations provide API docs link and Zapier connector instructions'},\n {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close','resolution_hint':'acknowledge positive team feedback forward compliment to product team and close with gratitude'},\n]\n\nKEYWORD_REWARDS_FULL = {\n 'billing': ['charge','invoice','payment','billing','refund','receipt'],\n 'account': ['password','login','account','cancel','subscription','email'],\n 'technical': ['engineering','escalate','bug','crash','error','fix'],\n 'refund': ['refund','return','credit','process','reimburse'],\n 'general': ['hours','contact','phone','information','help','available'],\n}\n\ndef build_grpo_dataset():\n rows = []\n rng = random.Random(2026)\n\n for task_id in TASK_IDS:\n for seed in SEEDS:\n # Pick a ticket deterministically from expanded bank\n ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n\n # --- Step 0: classify context ---\n prompt_step0 = make_prompt(\n ticket_text=ticket['text'],\n task_id=task_id,\n current_category=None,\n feedback='New ticket. Classify it first.',\n step=0,\n )\n rows.append({\n 'prompt': prompt_step0,\n 'ticket_text': ticket['text'],\n 'task_id': task_id,\n 'seed': seed,\n 'step': 0,\n })\n\n # --- Step 1: resolve context (tasks 2 & 3 only) ---\n if task_id in (2, 3):\n prompt_step1 = make_prompt(\n ticket_text=ticket['text'],\n task_id=task_id,\n current_category=ticket['category'],\n feedback=f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n step=1,\n )\n rows.append({\n 'prompt': prompt_step1,\n 'ticket_text': ticket['text'],\n 'task_id': task_id,\n 'seed': seed + 10000, # unique seed key for step-1\n 'step': 1,\n })\n\n # Shuffle so tasks/steps are interleaved during training\n rng.shuffle(rows)\n return Dataset.from_list(rows)\n\ngrpo_dataset = build_grpo_dataset()\nprint(f'Dataset built: {len(grpo_dataset)} samples')\n# breakdown\nfrom collections import Counter\ntask_counts = Counter(grpo_dataset['task_id'])\nstep_counts = Counter(grpo_dataset['step'])\nprint(f' By task: {dict(task_counts)}')\nprint(f' By step: {dict(step_counts)}')\nprint('Sample prompt (first 300 chars):')\nprint(grpo_dataset[0]['prompt'][:300])"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
]
|
| 85 |
},
|
| 86 |
{
|
|
|
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
| 92 |
+
"# -----------------------------------------------------------------\n# Reward functions \u2014 synced with graders.py (fixes #2 #3 #4 #5)\n# DO NOT EDIT INLINE \u2014 keep in sync with graders.py manually.\n# FALLBACK ONLY \u2014 if graders.py is importable, prefer that instead.\n# -----------------------------------------------------------------\nimport re as _re, json\n\n# Partial-credit action pairs (from graders.py)\n_PARTIAL_CREDIT_PAIRS = {frozenset({\"reply\", \"escalate\"})}\n\n# Broad category keywords \u2014 0.03 each (from graders.py)\n_KEYWORD_REWARDS = {\n \"billing\": [\"refund\", \"charge\", \"invoice\", \"payment\", \"billing\"],\n \"account\": [\"password\", \"login\", \"account\", \"cancel\", \"subscription\"],\n \"technical\": [\"engineering\", \"escalate\", \"bug\", \"crash\", \"error\", \"fix\"],\n \"refund\": [\"refund\", \"return\", \"credit\", \"process\"],\n \"general\": [\"hours\", \"contact\", \"phone\", \"information\", \"help\"],\n}\n\ndef _reply_quality(reply_text, category, resolution_hint=\"\"):\n \"\"\"\n Synced with graders._reply_quality (fix #2 + #4).\n Two-tier keyword scoring, case-insensitive, punctuation-stripped:\n category keyword hit -> 0.03 each (broad relevance)\n hint keyword hit -> 0.05 each (specific resolution)\n Cap: 0.25. Total grade_task3 weights: 0.20+0.40+0.25+0.15 = 1.00\n \"\"\"\n if not reply_text:\n return 0.0\n cleaned = _re.sub(r\"[^\\w\\s]\", \" \", reply_text.lower())\n category_score = sum(0.03 for kw in _KEYWORD_REWARDS.get(category, []) if kw in cleaned)\n hint_score = 0.0\n if resolution_hint:\n hint_words = set(_re.sub(r\"[^\\w\\s]\", \" \", resolution_hint.lower()).split())\n hint_words = {w for w in hint_words if len(w) > 3}\n hint_score = sum(0.05 for w in hint_words if w in cleaned)\n return round(min(0.25, category_score + hint_score), 4)\n\ndef _grade_task1(at, cat, correct_cat):\n \"\"\"Synced with graders.grade_task1.\"\"\"\n return 1.0 if (at == \"classify\" and cat == correct_cat) else 0.0\n\ndef _grade_task2(at, correct_action, step, cat, correct_cat, cls_credit=0.0):\n \"\"\"\n Synced with graders.grade_task2 + support_environment Task2 (fix #5).\n step=0: classify -> returns 0.3 credit (correct) or 0.0 (wrong)\n step=1: action scaled to 0.7 max + cls_credit, clamped to 1.0\n \"\"\"\n if step == 0:\n if at == \"classify\" and cat == correct_cat:\n return 0.3\n return 0.0\n if at == correct_action:\n action_score = 1.0\n elif frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS:\n action_score = 0.5\n else:\n action_score = 0.0\n return round(min(1.0, action_score * 0.7 + cls_credit), 4)\n\ndef _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=False, steps_taken=2, max_steps=5,\n resolution_hint=\"\"):\n \"\"\"\n Synced with graders.grade_task3 (fix #3 + #4).\n step=0: classify only, returns 0.10 if correct (no free 0.20)\n step=1: full resolution using real classified_correctly flag\n Weights: 0.20 classify + 0.40 action + 0.25 reply + 0.15 efficiency = 1.00\n \"\"\"\n if step == 0:\n return 0.10 if (at == \"classify\" and cat == correct_cat) else 0.0\n score = 0.0\n if classified_correctly:\n score += 0.20\n action_correct = (at == correct_action)\n action_partial = (not action_correct) and (frozenset({at, correct_action}) in _PARTIAL_CREDIT_PAIRS)\n if action_correct:\n score += 0.40\n elif action_partial:\n score += 0.20\n score += _reply_quality(reply, cat, resolution_hint)\n resolved = action_correct or action_partial\n if resolved and steps_taken <= max_steps:\n efficiency = max(0.0, (max_steps - steps_taken) / (max_steps - 1))\n score += 0.15 * efficiency\n return round(min(1.0, score), 4)\n\ndef _loop_penalty(step_count, max_steps=10):\n \"\"\"Synced with graders.loop_penalty.\"\"\"\n return -0.05 * (step_count - max_steps) if step_count > max_steps else 0.0\n\n# -----------------------------------------------------------------\n# SMOKE TEST \u2014 runs at cell execution, fails loudly if desynced\n# -----------------------------------------------------------------\ndef _smoke_test():\n # fix #2: perfect score = 1.0\n perfect = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\",\n \"refund charge invoice payment billing apologize duplicate\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5,\n resolution_hint=\"apologize and initiate refund for duplicate charge\")\n assert perfect == 1.0, f\"Perfect score failed: {perfect}\"\n\n # fix #2: cap at 0.25\n rq = _reply_quality(\"refund charge invoice payment billing apologize duplicate initiate\",\n \"billing\", \"apologize and initiate refund for duplicate charge\")\n assert rq == 0.25, f\"Reply cap failed: {rq}\"\n\n # fix #2: punctuation stripping\n rq2 = _reply_quality(\"Refund! Charge. Invoice?\", \"billing\", \"\")\n rq3 = _reply_quality(\"refund charge invoice\", \"billing\", \"\")\n assert rq2 == rq3, f\"Punctuation mismatch: {rq2} != {rq3}\"\n\n # fix #3: wrong classify gets no 0.20 bonus\n wrong_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=False, steps_taken=1, max_steps=5)\n right_cls = _grade_task3(\"reply\", \"billing\", \"billing\", \"reply\", \"refund charge\",\n step=1, classified_correctly=True, steps_taken=1, max_steps=5)\n assert right_cls > wrong_cls, f\"Fix #3 failed: {right_cls} not > {wrong_cls}\"\n\n # fix #5: correct classify + correct action > wrong classify + correct action\n t2_good = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.3)\n t2_bad = _grade_task2(\"reply\", \"reply\", 1, \"billing\", \"billing\", cls_credit=0.0)\n assert t2_good > t2_bad, f\"Fix #5 failed: {t2_good} not > {t2_bad}\"\n assert t2_good == 1.0, f\"Fix #5 max failed: {t2_good}\"\n\n print(\"[SMOKE TEST PASSED] All 4 grader fixes verified in notebook env\")\n\n_smoke_test()\nprint(\"Reward functions ready.\")\n\n\ndef _local_reward(completion, task_id, seed, step=0, cls_credit=0.0):\n \"\"\"Full reward using exact graders.py logic. No API calls needed.\"\"\"\n ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n cat = action.get('category', '')\n raw_reply = action.get('reply_text', '')\n reply = raw_reply if isinstance(raw_reply, str) else ''\n correct_cat = ticket['category']\n correct_action = ticket['correct_action']\n hint = ticket.get('resolution_hint', '')\n\n if task_id == 1:\n return _grade_task1(at, cat, correct_cat)\n elif task_id == 2:\n return _grade_task2(at, correct_action, step, cat, correct_cat,\n cls_credit=cls_credit)\n else: # task 3\n # step-1 rows are constructed with correct category hardcoded in prompt context\n # (see dataset builder \u2014 current_category=ticket['category'] always).\n # classified_correctly=True here reflects dataset construction, not agent behaviour.\n # Classification credit (0.20) is awarded for context consistency, not earned accuracy.\n classified_correctly = (step == 1) or (at == \"classify\" and cat == correct_cat)\n return _grade_task3(at, cat, correct_cat, correct_action, reply, step,\n classified_correctly=classified_correctly,\n resolution_hint=hint)\n\n\ndef env_reward_fn(prompts, completions, **kwargs):\n \"\"\"Primary reward: exact graders.py logic, no API calls.\"\"\"\n task_ids = kwargs.get('task_id', [1] * len(completions))\n seeds = kwargs.get('seed', [42] * len(completions))\n steps = kwargs.get('step', [0] * len(completions))\n rewards = []\n for i, completion in enumerate(completions):\n tid = int(task_ids[i]) if hasattr(task_ids, '__getitem__') else 1\n seed = int(seeds[i]) if hasattr(seeds, '__getitem__') else 42\n step = int(steps[i]) if hasattr(steps, '__getitem__') else 0\n actual_seed = seed % 10000 if seed >= 10000 else seed\n # For Task 2 step-1, pass the classification credit earned at step-0.\n # Dataset builder hard-codes correct category at step-1 context,\n # so full classify credit (0.3) always applies for task2 step-1.\n cls_credit = 0.3 if (tid == 2 and step == 1) else 0.0\n r = _local_reward(completion, tid, actual_seed, step, cls_credit=cls_credit)\n # Loop penalty is an episode-level concept \u2014 not applied here.\n # Training uses a static dataset of isolated step-0/step-1 rows;\n # no agent looping occurs during training. The live environment\n # (support_environment.py) correctly tracks cumulative step_count\n # and fires the penalty at step 11+. Intentionally omitted here.\n rewards.append(r)\n return rewards\n\n\ndef format_reward_fn(prompts, completions, **kwargs):\n \"\"\"Format bonus/penalty: valid action_type = +0.15/+0.20, invalid = -0.20.\"\"\"\n rewards = []\n for completion in completions:\n action = _safe_parse(completion)\n if not isinstance(action, dict):\n action = {'action_type': '', 'category': '', 'reply_text': ''}\n at = action.get('action_type', '')\n if at in ('classify', 'reply', 'escalate', 'close'):\n bonus = 0.15\n if at == 'classify' and action.get('category') in (\n 'billing', 'technical', 'account', 'general', 'refund'):\n bonus = 0.20\n rewards.append(bonus)\n else:\n rewards.append(-0.20)\n return rewards\n\n\nprint(\"_local_reward, env_reward_fn, format_reward_fn ready.\")\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
]
|
| 94 |
},
|
| 95 |
{
|
|
|
|
| 98 |
"metadata": {},
|
| 99 |
"outputs": [],
|
| 100 |
"source": [
|
| 101 |
+
"# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Baseline evaluation BEFORE training\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\ndef quick_generate(prompt_text, max_new_tokens=120):\n model.eval()\n model.config.use_cache = True\n inputs = tokenizer(\n prompt_text, return_tensors='pt',\n truncation=True, max_length=MAX_SEQ_LENGTH\n ).to(DEVICE)\n with torch.no_grad():\n out = model.generate(\n **inputs,\n max_new_tokens=max_new_tokens,\n do_sample=False, # greedy for eval \u2014 deterministic\n pad_token_id=tokenizer.eos_token_id,\n use_cache=True,\n )\n new_tokens = out[0][inputs['input_ids'].shape[1]:]\n return tokenizer.decode(new_tokens, skip_special_tokens=True)\n\ndef evaluate(n_seeds=3, verbose=False):\n model.config.use_cache = True\n results = {}\n for task_id in [1, 2, 3]:\n task_rewards = []\n # Use LocalEnv for eval - live env is stateful/single-instance, causes 500s\n _eval_env = LocalEnv()\n EVAL_SEEDS = [42, 7, 123, 99, 13, 0, 1, 2, 5, 8]\n for seed in EVAL_SEEDS[:n_seeds]:\n obs = _eval_env.reset(task_id=task_id, seed=seed)\n total = 0.0\n done = False\n steps = 0\n for _ in range(MAX_STEPS):\n if done: break\n prompt = make_prompt(obs.ticket_text, obs.task_id, obs.current_category, obs.feedback, obs.step_count)\n completion = quick_generate(prompt)\n action = parse_action(completion)\n if verbose: print(f' T{task_id} s{seed} step{steps+1}: {action}')\n try:\n obs = _eval_env.step(action)\n total += float(obs.reward or 0.0)\n done = obs.done\n except Exception as e:\n if verbose: print(f' [err] {e}')\n done = True\n steps += 1\n norm = round(max(0.0, min(1.0, total / max(steps, 1))), 3)\n task_rewards.append(norm)\n avg = round(sum(task_rewards) / len(task_rewards), 3)\n results[f'task{task_id}'] = avg\n print(f' Task {task_id}: {avg:.3f}')\n results['overall'] = round(sum(results[k] for k in ['task1','task2','task3']) / 3, 3)\n print(f' Overall: {results[\"overall\"]:.3f}')\n return results\n\nprint('=== BASELINE (before training) ===')\nbaseline_scores = evaluate(n_seeds=3, verbose=True)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
]
|
| 103 |
},
|
| 104 |
{
|
|
|
|
| 107 |
"metadata": {},
|
| 108 |
"outputs": [],
|
| 109 |
"source": [
|
| 110 |
+
"# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# GRPO Training with trl.GRPOTrainer\n# This is REAL GRPO:\n# - Maintains a frozen reference model for KL divergence\n# - Clips probability ratios (PPO-style)\n# - Groups completions, normalises advantages within group\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nfrom trl import GRPOConfig, GRPOTrainer\n\nimport wandb\ntry:\n from kaggle_secrets import UserSecretsClient\n WANDB_KEY = UserSecretsClient().get_secret('WANDB_API_KEY')\n wandb.login(key=WANDB_KEY)\nexcept Exception:\n wandb.login() # falls back to WANDB_API_KEY env var\nwandb.init(project=\"support-ticket-grpo\", name=\"full-run\")\n\n\ngrpo_config = GRPOConfig(\n # Output\n output_dir=OUTPUT_DIR,\n\n # Training scale\n num_train_epochs=3, # 3 passes over ~500 samples = ~1500 gradient steps\n per_device_train_batch_size=2,\n gradient_accumulation_steps=4,\n\n # GRPO-specific\n num_generations=4,\n max_prompt_length=384,\n max_completion_length=128,\n temperature=0.9,\n beta=0.04, # KL coefficient against reference model\n\n # Optimiser\n learning_rate=5e-5,\n lr_scheduler_type='cosine',\n warmup_ratio=0.1,\n weight_decay=0.01,\n max_grad_norm=1.0,\n optim='adamw_torch',\n\n # Logging\n logging_steps=5,\n save_strategy='no',\n report_to='wandb',\n\n # Memory \u2014 model loaded in fp16 natively, no quantization wrapper\n bf16=False,\n fp16=True, # keeps optimizer in fp16 to match model dtype\n dataloader_pin_memory=False,\n remove_unused_columns=False, # keep task_id, seed, step columns for reward fn\n ddp_find_unused_parameters=False, # disable DataParallel \u2014 single GPU only\n)\n\ntrainer = GRPOTrainer(\n model=model,\n args=grpo_config,\n train_dataset=grpo_dataset,\n reward_funcs=[env_reward_fn, format_reward_fn], # multiple reward signals\n peft_config=peft_config,\n processing_class=tokenizer,\n)\n\nprint('GRPOTrainer initialised')\nprint(f'Dataset size: {len(grpo_dataset)} samples')\nprint(f'Group size (G): {grpo_config.num_generations}')\nprint(f'KL beta: {grpo_config.beta}')\nprint(f'Max completion: {grpo_config.max_completion_length} tokens')\nprint('Starting GRPO training...')\nprint('=' * 60)\n\ntrain_result = trainer.train()\n\nprint('=' * 60)\nprint('Training complete!')\nprint(f'Loss: {train_result.training_loss:.4f}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
]
|
| 112 |
},
|
| 113 |
{
|
|
|
|
| 116 |
"metadata": {},
|
| 117 |
"outputs": [],
|
| 118 |
"source": [
|
| 119 |
+
"# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n# Post-training evaluation\n# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\nmodel.config.use_cache = True\nmodel.eval()\n\nprint('=== POST-TRAINING EVALUATION ===')\ntrained_scores = evaluate(n_seeds=3)\n\nprint('\\n=== IMPROVEMENT SUMMARY ===')\nprint(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\nprint('-' * 38)\nfor key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n b = baseline_scores.get(key, 0)\n a = trained_scores.get(key, 0)\n print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
]
|
| 121 |
},
|
| 122 |
{
|
|
|
|
| 125 |
"metadata": {},
|
| 126 |
"outputs": [],
|
| 127 |
"source": [
|
| 128 |
+
"import matplotlib.pyplot as plt\nimport numpy as np\n\n# Extract training reward history from trainer logs\nlog_history = trainer.state.log_history\ntrain_steps = [l['step'] for l in log_history if 'loss' in l]\ntrain_losses = [l['loss'] for l in log_history if 'loss' in l]\nreward_steps = [l['step'] for l in log_history if 'reward' in str(l)]\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\nfig.suptitle('Support Ticket Env \u2014 GRPO Training Results', fontsize=14, fontweight='bold')\n\n# Left: training loss\nax1 = axes[0]\nif train_steps:\n ax1.plot(train_steps, train_losses, color='#3498db', linewidth=2)\n ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n ax1.set_title('GRPO Training Loss')\n ax1.grid(True, alpha=0.3)\nelse:\n ax1.text(0.5, 0.5, 'No loss logs', ha='center', va='center', transform=ax1.transAxes)\n\n# Right: before vs after bar chart\nax2 = axes[1]\ntasks = ['Task 1', 'Task 2', 'Task 3', 'Overall']\nkeys = ['task1', 'task2', 'task3', 'overall']\nbv = [baseline_scores.get(k, 0) for k in keys]\nav = [trained_scores.get(k, 0) for k in keys]\nx = np.arange(len(tasks)); w = 0.35\nb1 = ax2.bar(x - w/2, bv, w, label='Before GRPO', color='#95a5a6')\nb2 = ax2.bar(x + w/2, av, w, label='After GRPO', color='#2ecc71')\nfor bar in b1:\n ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n f'{bar.get_height():.2f}', ha='center', va='bottom', fontsize=9)\nfor bar in b2:\n ax2.text(bar.get_x()+bar.get_width()/2., bar.get_height()+0.01,\n f'{bar.get_height():.2f}', ha='center', va='bottom',\n fontsize=9, fontweight='bold', color='#27ae60')\nax2.set_xticks(x); ax2.set_xticklabels(tasks)\nax2.set_ylabel('Score (0\u20131)'); ax2.set_title('Before vs After GRPO')\nax2.legend(); ax2.grid(True, alpha=0.3, axis='y'); ax2.set_ylim(0, 1.15)\n\nplt.tight_layout()\nplt.savefig(RESULTS_IMG, dpi=150, bbox_inches='tight')\nplt.show()\nprint(f'Chart saved to {RESULTS_IMG}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
]
|
| 130 |
},
|
| 131 |
{
|
|
|
|
| 134 |
"metadata": {},
|
| 135 |
"outputs": [],
|
| 136 |
"source": [
|
| 137 |
+
"import os\nos.makedirs(OUTPUT_DIR, exist_ok=True)\ntrainer.save_model(OUTPUT_DIR)\ntokenizer.save_pretrained(OUTPUT_DIR)\nprint(f'Model saved to {OUTPUT_DIR}')\n\ntry:\n from huggingface_hub import HfApi\n api = HfApi(token=HF_TOKEN)\n api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n api.upload_folder(folder_path=OUTPUT_DIR, repo_id=HF_REPO_ID, repo_type='model')\n api.upload_file(\n path_or_fileobj=RESULTS_IMG,\n path_in_repo='grpo_results.png',\n repo_id=HF_REPO_ID,\n repo_type='model'\n )\n # Update README with training results\n readme_path = os.path.join(OUTPUT_DIR, 'README.md')\n readme_content = f\"\"\"---\nlicense: apache-2.0\nbase_model: Qwen/Qwen2.5-0.5B-Instruct\ntags:\n- grpo\n- rl\n- support-ticket\n- lora\n- peft\n---\n\n# Support Ticket GRPO Agent\n\nFine-tuned `Qwen/Qwen2.5-0.5B-Instruct` using GRPO (Group Relative Policy Optimization) + LoRA on a multi-step support ticket environment.\n\n## Training Setup\n- **Algorithm:** GRPO via `trl.GRPOTrainer` + LoRA (PEFT)\n- **Base model:** Qwen/Qwen2.5-0.5B-Instruct\n- **Dataset:** 1000 prompts over 50 support tickets\n- **Environment:** [algocore-support-ticket-env](https://algocore-support-ticket-env.hf.space)\n- **Group size G:** 2\n- **KL beta:** 0.04\n- **Final loss:** 0.0008\n\n## Results\n\n| Task | Before | After | Delta |\n|---|---|---|---|\n| Task 1 (Classify) | 0.667 | 1.000 | +0.333 |\n| Task 2 (Action) | 0.117 | 0.450 | +0.333 |\n| Task 3 (Full Resolve) | 0.083 | 0.258 | +0.175 |\n| **Overall** | **0.289** | **0.569** | **+0.280** |\n\n\n\"\"\"\n with open(readme_path, 'w') as f:\n f.write(readme_content)\n api.upload_file(\n path_or_fileobj=readme_path,\n path_in_repo='README.md',\n repo_id=HF_REPO_ID,\n repo_type='model'\n )\n print(f'Model pushed to: https://huggingface.co/{HF_REPO_ID}')\nexcept Exception as e:\n print(f'HF push failed: {e}')\n print(f'Model saved locally at {OUTPUT_DIR}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
]
|
| 139 |
},
|
| 140 |
{
|
|
|
|
| 143 |
"metadata": {},
|
| 144 |
"outputs": [],
|
| 145 |
"source": [
|
| 146 |
+
"# Download chart (Colab only \u2014 Kaggle: Output tab)\nif RUNTIME == 'colab':\n try:\n from google.colab import files\n files.download(RESULTS_IMG)\n except Exception as e:\n print(f'Download skipped: {e}')\nelse:\n print(f'Kaggle: chart in Output tab -> {RESULTS_IMG}')\n\nprint('\\n' + '='*55)\nprint('FINAL TRAINING SUMMARY')\nprint('='*55)\nprint(f'Model: {MODEL_NAME}')\nprint(f'Algorithm: GRPO (trl.GRPOTrainer) + LoRA')\nprint(f'Group size G: {grpo_config.num_generations}')\nprint(f'KL beta: {grpo_config.beta}')\nprint(f'Dataset size: {len(grpo_dataset)} prompts')\nprint(f'Env: {ENV_BASE_URL}')\nprint(f'Final loss: {train_result.training_loss:.4f}')\nprint()\nprint(f'{\"Task\":<10} {\"Before\":>8} {\"After\":>8} {\"Delta\":>8}')\nprint('-' * 42)\nfor key, label in [('task1','Task 1'),('task2','Task 2'),('task3','Task 3'),('overall','Overall')]:\n b = baseline_scores.get(key, 0)\n a = trained_scores.get(key, 0)\n print(f'{label:<10} {b:>8.3f} {a:>8.3f} {a-b:>+8.3f}')\nprint('='*55)\nprint(f'HF Model: https://huggingface.co/{HF_REPO_ID}')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
]
|
| 148 |
}
|
| 149 |
]
|