Spaces:
Sleeping
Sleeping
File size: 22,711 Bytes
cf0d796 2e81e98 cf0d796 2e81e98 cf0d796 2e81e98 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 | {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Support Ticket Env — SFT Pre-Training\n",
"**OpenEnv x Scalar Hackathon**\n",
"\n",
"Step 1 of 2-stage training: **Supervised Fine-Tuning** on gold-label examples.\n",
"This teaches the model correct JSON format + task logic before GRPO optimization.\n",
"\n",
"- **Model:** Qwen/Qwen2.5-0.5B-Instruct\n",
"- **Algorithm:** SFT via `trl.SFTTrainer`\n",
"- **Dataset:** 1000 gold-label (prompt, completion) pairs\n",
"- **Runtime:** ~15-20 min on Kaggle T4\n",
"- **Output:** `/kaggle/working/sft-model` → used as base for GRPO"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n",
"%pip install -q requests matplotlib\n",
"print('Installation complete')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, json, random\n",
"\n",
"# Resolve the Hugging Face token from the first source that yields one:\n",
"# Colab user secrets, then Kaggle secrets, then the HF_TOKEN environment variable.\n",
"HF_TOKEN = ''\n",
"try:\n",
"    from google.colab import userdata\n",
"    HF_TOKEN = userdata.get('HF_TOKEN') or ''\n",
"except Exception:\n",
"    # Best-effort lookup: fall through to the next source on any failure.\n",
"    # Unlike a bare `except`, this lets KeyboardInterrupt/SystemExit propagate.\n",
"    pass\n",
"if not HF_TOKEN:\n",
"    try:\n",
"        from kaggle_secrets import UserSecretsClient\n",
"        HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n",
"    except Exception:\n",
"        # Best-effort lookup: fall through to the environment variable.\n",
"        pass\n",
"if not HF_TOKEN:\n",
"    HF_TOKEN = os.environ.get('HF_TOKEN', '')\n",
"if not HF_TOKEN:\n",
"    raise ValueError('HF_TOKEN not found.')\n",
"\n",
"print('HF_TOKEN loaded OK')\n",
"\n",
"MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
"# The presence of /kaggle/working selects the Kaggle paths; anything else uses the /content paths.\n",
"RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
"SFT_OUT = '/kaggle/working/sft-model' if RUNTIME == 'kaggle' else '/content/sft-model'\n",
"RESULTS_IMG = '/kaggle/working/sft_results.png' if RUNTIME == 'kaggle' else '/content/sft_results.png'\n",
"HF_REPO_ID = 'AlgoCore/support-ticket-sft-model'\n",
"\n",
"# Export the token under both variable names for libraries that read it from the environment.\n",
"os.environ['HF_TOKEN'] = HF_TOKEN\n",
"os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n",
"\n",
"import torch\n",
"if torch.cuda.is_available():\n",
"    DEVICE = 'cuda:0'\n",
"    print('GPU:', torch.cuda.get_device_name(0))\n",
"    print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9,1), 'GB')\n",
"else:\n",
"    # Device properties cannot be queried without CUDA, so report that\n",
"    # instead of crashing on the VRAM lookup.\n",
"    DEVICE = 'cpu'\n",
"    print('GPU:', 'NO GPU')\n",
"    print('VRAM:', 'n/a')\n",
"print('Runtime:', RUNTIME, '| SFT output:', SFT_OUT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ── Gold-label ticket bank (50 tickets, all categories) ──────────────────\n",
"ALL_TICKETS = [\n",
" # billing (12)\n",
" {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
" {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
" {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
" {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n",
" {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n",
" {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n",
" {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n",
" {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n",
" {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n",
" {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n",
" {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n",
" {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n",
" # account (10)\n",
" {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
" {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
" {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n",
" {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n",
" {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n",
" {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n",
" {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n",
" {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n",
" {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n",
" {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n",
" # technical (10)\n",
" {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n",
" {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n",
" # refund (8)\n",
" {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
" {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n",
" {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n",
" {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n",
" {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n",
" {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n",
" {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n",
" {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n",
" # general (10)\n",
" {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
" {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
" {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n",
" {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n",
" {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n",
" {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n",
" {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n",
" {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n",
" {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n",
" {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n",
"]\n",
"\n",
"# Gold reply templates per category\n",
"GOLD_REPLIES = {\n",
" 'billing': 'Thank you for reaching out about this billing issue. I can see the charge on your account. Our billing team will review your invoice and process any corrections. You will receive an updated receipt via email within 24 hours.',\n",
" 'account': 'Thank you for contacting us about your account. I understand how frustrating this can be. Please try resetting your password using the link sent to your registered email. If the issue persists, our account team will assist you directly.',\n",
" 'technical': 'I am escalating this technical issue to our engineering team immediately. They will investigate the root cause and provide a fix. You will receive a follow-up within 2 business hours.',\n",
" 'refund': 'I understand your refund request. Our refund policy allows for full refunds within 30 days. I am processing your refund now and it will appear on your credit card within 5-7 business days.',\n",
" 'general': 'Thank you for reaching out! Our support team is available Monday to Friday, 9am-6pm. You can also contact us via phone or email. We are happy to help with any questions you have.',\n",
" 'close': 'Thank you for letting us know the issue has been resolved. We are glad we could help! Please do not hesitate to contact us if you need anything else.',\n",
"}\n",
"\n",
"SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n",
"\n",
"VALID action_type values: classify, reply, escalate, close\n",
"VALID category values: billing, technical, account, general, refund\n",
"\n",
"For classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\n",
"For reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\n",
"For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n",
"For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n",
"\n",
"RULES:\n",
"- task_id=1: ALWAYS output action_type=classify first\n",
"- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n",
"- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n",
"- technical/crash/error/bug tickets -> escalate\n",
"- thank you/resolved tickets -> close\n",
"- billing/account/refund/general -> reply\n",
"- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n",
"\n",
"def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n",
"    \"\"\"Build the chat messages (system + user) for one environment step.\n",
"\n",
"    The user turn is a JSON object with the keys ticket, task_id,\n",
"    current_category, feedback and step, in that order; the system turn\n",
"    carries SYSTEM_PROMPT.\n",
"    \"\"\"\n",
"    observation = dict(\n",
"        ticket=ticket_text,\n",
"        task_id=task_id,\n",
"        current_category=current_category,\n",
"        feedback=feedback,\n",
"        step=step,\n",
"    )\n",
"    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}\n",
"    user_turn = {'role': 'user', 'content': json.dumps(observation)}\n",
"    return [system_turn, user_turn]\n",
"\n",
"def gold_completion(ticket, task_id, step):\n",
"    \"\"\"Return the gold-label JSON action string for `ticket` at the given task and step.\n",
"\n",
"    Task 1, and step 0 of any other task, yield a classify action carrying the\n",
"    ticket's category. Any later step yields the ticket's resolution action:\n",
"    escalate (with the 'technical' reply template), close (with the 'close'\n",
"    template), or otherwise reply with the template for the ticket's category.\n",
"    \"\"\"\n",
"    category = ticket['category']\n",
"    resolution = ticket['correct_action']\n",
"\n",
"    # Tasks 2 and 3 follow the same two-step script, so one code path serves both.\n",
"    if task_id == 1 or step == 0:\n",
"        return json.dumps({'action_type': 'classify', 'category': category})\n",
"\n",
"    if resolution == 'escalate':\n",
"        action_type, template_key = 'escalate', 'technical'\n",
"    elif resolution == 'close':\n",
"        action_type, template_key = 'close', 'close'\n",
"    else:\n",
"        action_type, template_key = 'reply', category\n",
"    return json.dumps({'action_type': action_type, 'reply_text': GOLD_REPLIES[template_key]})\n",
"\n",
"print(f'Ticket bank: {len(ALL_TICKETS)} tickets')\n",
"# Look the sample up by id rather than by list position, so reordering the\n",
"# bank cannot silently change which ticket the labels below refer to.\n",
"sample_ticket = next(t for t in ALL_TICKETS if t['id'] == 'T001')\n",
"print('Sample gold completion (T001, task2, step0):', gold_completion(sample_ticket, 2, 0))\n",
"print('Sample gold completion (T001, task2, step1):', gold_completion(sample_ticket, 2, 1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\n",
"# Only fall back to EOS when the tokenizer ships without a pad token. Aliasing\n",
"# pad to EOS unconditionally means any collator that masks labels by pad id\n",
"# would also mask the real end-of-turn token.\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"tokenizer.padding_side = 'right' # SFT uses right-padding\n",
"\n",
"# Seeds index into the ticket bank modulo its length, so with more seeds than\n",
"# tickets each ticket is used several times per task.\n",
"SEEDS = list(range(200))\n",
"TASK_IDS = [1, 2, 3]\n",
"\n",
"def build_sft_dataset():\n",
"    \"\"\"Build the chat-templated SFT train and validation datasets.\n",
"\n",
"    For every task id in TASK_IDS and every seed in SEEDS, the ticket at\n",
"    index `seed % len(ALL_TICKETS)` contributes a step-0 classify example and,\n",
"    for tasks 2 and 3, a step-1 resolution example. Each row carries the\n",
"    rendered conversation in 'text' plus its 'task_id' and 'step'.\n",
"\n",
"    Validation rows come from tickets that never appear in training. Every\n",
"    ticket is repeated across seeds, so a row-level random split would put\n",
"    identical texts in both splits and the eval loss would only measure\n",
"    memorisation. About 10% of the tickets in each (category, correct_action)\n",
"    group are held out; groups too small to spare one stay fully in training.\n",
"\n",
"    Returns:\n",
"        (train_dataset, val_dataset) as two `datasets.Dataset` objects.\n",
"    \"\"\"\n",
"    val_fraction = 0.1\n",
"    rng = random.Random(42)\n",
"\n",
"    # Choose the held-out ticket indices, stratified by (category, correct_action).\n",
"    groups = {}\n",
"    for idx, ticket in enumerate(ALL_TICKETS):\n",
"        groups.setdefault((ticket['category'], ticket['correct_action']), []).append(idx)\n",
"    held_out = set()\n",
"    for indices in groups.values():\n",
"        held_out.update(rng.sample(indices, round(len(indices) * val_fraction)))\n",
"\n",
"    def render(messages, completion):\n",
"        # The 'text' column holds the full conversation, assistant turn included.\n",
"        return tokenizer.apply_chat_template(\n",
"            messages + [{'role': 'assistant', 'content': completion}],\n",
"            tokenize=False,\n",
"        )\n",
"\n",
"    tagged_rows = []  # (ticket index, row) pairs\n",
"    for task_id in TASK_IDS:\n",
"        for seed in SEEDS:\n",
"            ticket_idx = seed % len(ALL_TICKETS)\n",
"            ticket = ALL_TICKETS[ticket_idx]\n",
"\n",
"            # Step 0: classify\n",
"            messages_0 = make_prompt(ticket['text'], task_id, None, 'New ticket. Classify it first.', 0)\n",
"            text_0 = render(messages_0, gold_completion(ticket, task_id, 0))\n",
"            tagged_rows.append((ticket_idx, {'text': text_0, 'task_id': task_id, 'step': 0}))\n",
"\n",
"            # Step 1: resolve (tasks 2 & 3)\n",
"            if task_id in (2, 3):\n",
"                messages_1 = make_prompt(\n",
"                    ticket['text'], task_id,\n",
"                    ticket['category'],\n",
"                    f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n",
"                    1,\n",
"                )\n",
"                text_1 = render(messages_1, gold_completion(ticket, task_id, 1))\n",
"                tagged_rows.append((ticket_idx, {'text': text_1, 'task_id': task_id, 'step': 1}))\n",
"\n",
"    rng.shuffle(tagged_rows)\n",
"    if held_out:\n",
"        train_rows = [row for idx, row in tagged_rows if idx not in held_out]\n",
"        val_rows = [row for idx, row in tagged_rows if idx in held_out]\n",
"    else:\n",
"        # Bank too small to hold any ticket out: fall back to a 90/10 row-level split.\n",
"        split = int(len(tagged_rows) * 0.9)\n",
"        train_rows = [row for _, row in tagged_rows[:split]]\n",
"        val_rows = [row for _, row in tagged_rows[split:]]\n",
"    return Dataset.from_list(train_rows), Dataset.from_list(val_rows)\n",
"\n",
"train_ds, val_ds = build_sft_dataset()\n",
"print(f'SFT dataset — train: {len(train_ds)}, val: {len(val_ds)}')\n",
"print('Sample text (first 400 chars):')\n",
"print(train_ds[0]['text'][:400])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM\n",
"from peft import LoraConfig, TaskType\n",
"from trl import SFTConfig, SFTTrainer\n",
"\n",
"print(f'Loading {MODEL_NAME}...')\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL_NAME,\n",
"    dtype=torch.float16,\n",
"    device_map={'': 0},\n",
"    token=HF_TOKEN,\n",
")\n",
"# Caching is switched off for training; the evaluation cell turns it back on before generating.\n",
"model.config.use_cache = False\n",
"\n",
"# LoRA adapters on all attention and MLP projection layers.\n",
"peft_config = LoraConfig(\n",
"    task_type=TaskType.CAUSAL_LM,\n",
"    r=16,\n",
"    lora_alpha=32,\n",
"    target_modules=['q_proj','v_proj','k_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
"    lora_dropout=0.05,\n",
"    bias='none',\n",
")\n",
"\n",
"sft_config = SFTConfig(\n",
"    output_dir=SFT_OUT,\n",
"    num_train_epochs=3,\n",
"    per_device_train_batch_size=4,\n",
"    per_device_eval_batch_size=4,\n",
"    gradient_accumulation_steps=2,\n",
"    learning_rate=2e-4,\n",
"    lr_scheduler_type='cosine',\n",
"    warmup_ratio=0.05,\n",
"    weight_decay=0.01,\n",
"    max_grad_norm=1.0,\n",
"    fp16=True,\n",
"    bf16=False,\n",
"    logging_steps=10,\n",
"    eval_strategy='steps',\n",
"    eval_steps=50,\n",
"    save_strategy='no',\n",
"    report_to='none',\n",
"    dataloader_pin_memory=False,\n",
"    ddp_find_unused_parameters=False,\n",
"    # Sequence length and text column are SFTConfig options; the SFTTrainer\n",
"    # constructor in the pinned trl range does not accept them as keyword\n",
"    # arguments (and the length option is named max_length there).\n",
"    max_length=512,\n",
"    dataset_text_field='text',\n",
")\n",
"\n",
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    args=sft_config,\n",
"    train_dataset=train_ds,\n",
"    eval_dataset=val_ds,\n",
"    peft_config=peft_config,\n",
"    processing_class=tokenizer,\n",
")\n",
"\n",
"print(f'SFTTrainer ready | train={len(train_ds)} val={len(val_ds)}')\n",
"print('Starting SFT...')\n",
"print('=' * 60)\n",
"\n",
"sft_result = trainer.train()\n",
"\n",
"print('=' * 60)\n",
"print(f'SFT complete! Loss: {sft_result.training_loss:.4f}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# Save the trained model and the tokenizer side by side in SFT_OUT.\n",
"os.makedirs(SFT_OUT, exist_ok=True)\n",
"trainer.save_model(SFT_OUT)\n",
"tokenizer.save_pretrained(SFT_OUT)\n",
"print(f'SFT model saved to {SFT_OUT}')\n",
"\n",
"# Push to HF Hub (best-effort: a failed upload is reported but does not stop the notebook)\n",
"try:\n",
"    from huggingface_hub import HfApi\n",
"    api = HfApi(token=HF_TOKEN)\n",
"    api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
"    api.upload_folder(folder_path=SFT_OUT, repo_id=HF_REPO_ID, repo_type='model')\n",
"    print(f'SFT model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
"except Exception as e:\n",
"    print(f'HF push failed: {e}')\n",
"\n",
"print('\\n✅ SFT done. Now run train_grpo.ipynb and set MODEL_NAME to:')\n",
"print(f' MODEL_NAME = \"{SFT_OUT}\" # or \"{HF_REPO_ID}\"')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick eval to verify SFT worked\n",
"import re\n",
"\n",
"model.eval()\n",
"model.config.use_cache = True\n",
"\n",
"def sft_generate(ticket_text, task_id, step=0, current_category=None, feedback=None):\n",
"    \"\"\"Greedy-decode the model's action for one ticket prompt and return the generated text.\n",
"\n",
"    Args:\n",
"        ticket_text: Raw ticket text placed in the prompt.\n",
"        task_id: Task id placed in the prompt.\n",
"        step: Step index placed in the prompt.\n",
"        current_category: Category already assigned to the ticket, or None.\n",
"        feedback: Feedback string for the prompt. When None, the string used\n",
"            for this step when the SFT examples were built is substituted, so\n",
"            the evaluation prompt matches the prompts seen in training.\n",
"\n",
"    Returns:\n",
"        The decoded completion (at most 80 new tokens) with special tokens removed.\n",
"    \"\"\"\n",
"    if feedback is None:\n",
"        if step == 0:\n",
"            feedback = 'New ticket. Classify it first.'\n",
"        else:\n",
"            feedback = f'Category set to {current_category}. Now resolve the ticket.'\n",
"    messages = make_prompt(ticket_text, task_id, current_category, feedback, step)\n",
"    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512).to(DEVICE)\n",
"    with torch.no_grad():\n",
"        out = model.generate(\n",
"            **inputs, max_new_tokens=80, do_sample=False,\n",
"            pad_token_id=tokenizer.eos_token_id, use_cache=True\n",
"        )\n",
"    # Slice off the prompt tokens so only the newly generated text is decoded.\n",
"    return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
"\n",
"print('=== SFT Quick Eval ===')\n",
"# (ticket text, task_id, step, current_category)\n",
"test_cases = [\n",
"    ('I was charged twice this month.', 1, 0, None),\n",
"    ('App crashes on file upload.', 1, 0, None),\n",
"    ('Thank you, issue resolved!', 1, 0, None),\n",
"    ('I want a refund.', 2, 0, None),\n",
"    ('I want a refund.', 2, 1, 'refund'),\n",
"    ('API returning 500 errors.', 3, 0, None),\n",
"    ('API returning 500 errors.', 3, 1, 'technical'),\n",
"]\n",
"for text, tid, step, cat in test_cases:\n",
"    out = sft_generate(text, tid, step, cat)\n",
"    print(f'T{tid} step{step}: [{text[:35]}] -> {out[:80]}')"
]
}
]
} |