{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Support Ticket Env — SFT Pre-Training\n", "**OpenEnv x Scalar Hackathon**\n", "\n", "Step 1 of 2-stage training: **Supervised Fine-Tuning** on gold-label examples.\n", "This teaches the model correct JSON format + task logic before GRPO optimization.\n", "\n", "- **Model:** Qwen/Qwen2.5-0.5B-Instruct\n", "- **Algorithm:** SFT via `trl.SFTTrainer`\n", "- **Dataset:** 1000 gold-label (prompt, completion) pairs\n", "- **Runtime:** ~15-20 min on Kaggle T4\n", "- **Output:** `/kaggle/working/sft-model` → used as base for GRPO" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n", "!pip install -q requests matplotlib\n", "print('Installation complete')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, json, random\n", "\n", "HF_TOKEN = ''\n", "try:\n", " from google.colab import userdata\n", " HF_TOKEN = userdata.get('HF_TOKEN') or ''\n", "except: pass\n", "if not HF_TOKEN:\n", " try:\n", " from kaggle_secrets import UserSecretsClient\n", " HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n", " except: pass\n", "if not HF_TOKEN:\n", " HF_TOKEN = os.environ.get('HF_TOKEN', '')\n", "if not HF_TOKEN:\n", " raise ValueError('HF_TOKEN not found.')\n", "\n", "print('HF_TOKEN loaded OK')\n", "\n", "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n", "RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n", "SFT_OUT = '/kaggle/working/sft-model' if RUNTIME == 'kaggle' else '/content/sft-model'\n", "RESULTS_IMG = '/kaggle/working/sft_results.png' if RUNTIME == 'kaggle' else '/content/sft_results.png'\n", "HF_REPO_ID = 'AlgoCore/support-ticket-sft-model'\n", "\n", "os.environ['HF_TOKEN'] = HF_TOKEN\n", "os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n", "\n", "import torch\n", "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", "print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU')\n", "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9,1), 'GB')\n", "print('Runtime:', RUNTIME, '| SFT output:', SFT_OUT)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ── Gold-label ticket bank (50 tickets, all categories) ──────────────────\n", "ALL_TICKETS = [\n", " # billing (12)\n", " {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n", " {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n", " {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n", " {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n", " {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n", " {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n", " {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n", " {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n", " {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n", " {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n", " {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n", " {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n", " # account (10)\n", " {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n", " {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n", " {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n", " {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n", " {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n", " {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n", " {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n", " {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n", " {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n", " {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n", " # technical (10)\n", " {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n", " {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n", " {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n", " {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n", " {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n", " {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n", " {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n", " {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n", " {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n", " {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n", " # refund (8)\n", " {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n", " {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n", " {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n", " {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n", " {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n", " {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n", " {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n", " {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n", " # general (10)\n", " {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n", " {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n", " {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n", " {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n", " {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n", " {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n", " {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n", " {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n", " {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n", " {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n", "]\n", "\n", "# Gold reply templates per category\n", "GOLD_REPLIES = {\n", " 'billing': 'Thank you for reaching out about this billing issue. I can see the charge on your account. Our billing team will review your invoice and process any corrections. You will receive an updated receipt via email within 24 hours.',\n", " 'account': 'Thank you for contacting us about your account. I understand how frustrating this can be. Please try resetting your password using the link sent to your registered email. If the issue persists, our account team will assist you directly.',\n", " 'technical': 'I am escalating this technical issue to our engineering team immediately. They will investigate the root cause and provide a fix. You will receive a follow-up within 2 business hours.',\n", " 'refund': 'I understand your refund request. Our refund policy allows for full refunds within 30 days. I am processing your refund now and it will appear on your credit card within 5-7 business days.',\n", " 'general': 'Thank you for reaching out! Our support team is available Monday to Friday, 9am-6pm. You can also contact us via phone or email. We are happy to help with any questions you have.',\n", " 'close': 'Thank you for letting us know the issue has been resolved. We are glad we could help! Please do not hesitate to contact us if you need anything else.',\n", "}\n", "\n", "SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n", "\n", "VALID action_type values: classify, reply, escalate, close\n", "VALID category values: billing, technical, account, general, refund\n", "\n", "For classify: {\"action_type\": \"classify\", \"category\": \"\"}\n", "For reply: {\"action_type\": \"reply\", \"reply_text\": \"\"}\n", "For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n", "For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n", "\n", "RULES:\n", "- task_id=1: ALWAYS output action_type=classify first\n", "- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n", "- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n", "- technical/crash/error/bug tickets -> escalate\n", "- thank you/resolved tickets -> close\n", "- billing/account/refund/general -> reply\n", "- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n", "\n", "def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n", " user_msg = json.dumps({\n", " 'ticket': ticket_text,\n", " 'task_id': task_id,\n", " 'current_category': current_category,\n", " 'feedback': feedback,\n", " 'step': step,\n", " })\n", " return [\n", " {'role': 'system', 'content': SYSTEM_PROMPT},\n", " {'role': 'user', 'content': user_msg},\n", " ]\n", "\n", "def gold_completion(ticket, task_id, step):\n", " \"\"\"Generate the perfect gold-label completion for a given ticket + step.\"\"\"\n", " cat = ticket['category']\n", " action = ticket['correct_action']\n", "\n", " if task_id == 1:\n", " # Always classify\n", " return json.dumps({'action_type': 'classify', 'category': cat})\n", "\n", " elif task_id == 2:\n", " if step == 0:\n", " return json.dumps({'action_type': 'classify', 'category': cat})\n", " else:\n", " if action == 'escalate':\n", " return json.dumps({'action_type': 'escalate', 'reply_text': GOLD_REPLIES['technical']})\n", " elif action == 'close':\n", " return json.dumps({'action_type': 'close', 'reply_text': GOLD_REPLIES['close']})\n", " else:\n", " return json.dumps({'action_type': 'reply', 'reply_text': GOLD_REPLIES[cat]})\n", "\n", " else: # task 3\n", " if step == 0:\n", " return json.dumps({'action_type': 'classify', 'category': cat})\n", " else:\n", " if action == 'escalate':\n", " return json.dumps({'action_type': 'escalate', 'reply_text': GOLD_REPLIES['technical']})\n", " elif action == 'close':\n", " return json.dumps({'action_type': 'close', 'reply_text': GOLD_REPLIES['close']})\n", " else:\n", " return json.dumps({'action_type': 'reply', 'reply_text': GOLD_REPLIES[cat]})\n", "\n", "print(f'Ticket bank: {len(ALL_TICKETS)} tickets')\n", "print('Sample gold completion (T001, task2, step0):', gold_completion(ALL_TICKETS[22], 2, 0))\n", "print('Sample gold completion (T001, task2, step1):', gold_completion(ALL_TICKETS[22], 2, 1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import Dataset\n", "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = 'right' # SFT uses right-padding\n", "\n", "SEEDS = list(range(200))\n", "TASK_IDS = [1, 2, 3]\n", "\n", "def build_sft_dataset():\n", " rows = []\n", " rng = random.Random(42)\n", " for task_id in TASK_IDS:\n", " for seed in SEEDS:\n", " ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n", "\n", " # Step 0: classify\n", " messages_0 = make_prompt(ticket['text'], task_id, None, 'New ticket. Classify it first.', 0)\n", " completion_0 = gold_completion(ticket, task_id, 0)\n", " # SFTTrainer expects 'text' column with full conversation\n", " full_0 = tokenizer.apply_chat_template(\n", " messages_0 + [{'role': 'assistant', 'content': completion_0}],\n", " tokenize=False\n", " )\n", " rows.append({'text': full_0, 'task_id': task_id, 'step': 0})\n", "\n", " # Step 1: resolve (tasks 2 & 3)\n", " if task_id in (2, 3):\n", " messages_1 = make_prompt(\n", " ticket['text'], task_id,\n", " ticket['category'],\n", " f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n", " 1\n", " )\n", " completion_1 = gold_completion(ticket, task_id, 1)\n", " full_1 = tokenizer.apply_chat_template(\n", " messages_1 + [{'role': 'assistant', 'content': completion_1}],\n", " tokenize=False\n", " )\n", " rows.append({'text': full_1, 'task_id': task_id, 'step': 1})\n", "\n", " rng.shuffle(rows)\n", " # Split 90/10 train/val\n", " split = int(len(rows) * 0.9)\n", " return Dataset.from_list(rows[:split]), Dataset.from_list(rows[split:])\n", "\n", "train_ds, val_ds = build_sft_dataset()\n", "print(f'SFT dataset — train: {len(train_ds)}, val: {len(val_ds)}')\n", "print('Sample text (first 400 chars):')\n", "print(train_ds[0]['text'][:400])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM\n", "from peft import LoraConfig, TaskType\n", "from trl import SFTConfig, SFTTrainer\n", "\n", "print(f'Loading {MODEL_NAME}...')\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_NAME,\n", " dtype=torch.float16,\n", " device_map={'': 0},\n", " token=HF_TOKEN,\n", ")\n", "model.config.use_cache = False\n", "\n", "peft_config = LoraConfig(\n", " task_type=TaskType.CAUSAL_LM,\n", " r=16,\n", " lora_alpha=32,\n", " target_modules=['q_proj','v_proj','k_proj','o_proj','gate_proj','up_proj','down_proj'],\n", " lora_dropout=0.05,\n", " bias='none',\n", ")\n", "\n", "sft_config = SFTConfig(\n", " output_dir=SFT_OUT,\n", " num_train_epochs=3,\n", " per_device_train_batch_size=4,\n", " per_device_eval_batch_size=4,\n", " gradient_accumulation_steps=2,\n", " learning_rate=2e-4,\n", " lr_scheduler_type='cosine',\n", " warmup_ratio=0.05,\n", " weight_decay=0.01,\n", " max_grad_norm=1.0,\n", " fp16=True,\n", " bf16=False,\n", " logging_steps=10,\n", " eval_strategy='steps',\n", " eval_steps=50,\n", " save_strategy='no',\n", " report_to='none',\n", " dataloader_pin_memory=False,\n", " ddp_find_unused_parameters=False,\n", ")\n", "\n", "trainer = SFTTrainer(\n", " model=model,\n", " args=sft_config,\n", " train_dataset=train_ds,\n", " eval_dataset=val_ds,\n", " peft_config=peft_config,\n", " processing_class=tokenizer,\n", " max_seq_length=512,\n", " dataset_text_field='text',\n", ")\n", "\n", "print(f'SFTTrainer ready | train={len(train_ds)} val={len(val_ds)}')\n", "print('Starting SFT...')\n", "print('=' * 60)\n", "\n", "sft_result = trainer.train()\n", "\n", "print('=' * 60)\n", "print(f'SFT complete! Loss: {sft_result.training_loss:.4f}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.makedirs(SFT_OUT, exist_ok=True)\n", "trainer.save_model(SFT_OUT)\n", "tokenizer.save_pretrained(SFT_OUT)\n", "print(f'SFT model saved to {SFT_OUT}')\n", "\n", "# Push to HF Hub\n", "try:\n", " from huggingface_hub import HfApi\n", " api = HfApi(token=HF_TOKEN)\n", " api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n", " api.upload_folder(folder_path=SFT_OUT, repo_id=HF_REPO_ID, repo_type='model')\n", " print(f'SFT model pushed to: https://huggingface.co/{HF_REPO_ID}')\n", "except Exception as e:\n", " print(f'HF push failed: {e}')\n", "\n", "print('\\n✅ SFT done. Now run train_grpo.ipynb and set MODEL_NAME to:')\n", "print(f' MODEL_NAME = \"{SFT_OUT}\" # or \"{HF_REPO_ID}\"')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Quick eval to verify SFT worked\n", "import re\n", "\n", "model.eval()\n", "model.config.use_cache = True\n", "\n", "def sft_generate(ticket_text, task_id, step=0, current_category=None):\n", " messages = make_prompt(ticket_text, task_id, current_category, 'New ticket.', step)\n", " prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", " inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512).to(DEVICE)\n", " with torch.no_grad():\n", " out = model.generate(\n", " **inputs, max_new_tokens=80, do_sample=False,\n", " pad_token_id=tokenizer.eos_token_id, use_cache=True\n", " )\n", " return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n", "\n", "print('=== SFT Quick Eval ===')\n", "test_cases = [\n", " ('I was charged twice this month.', 1, 0, None),\n", " ('App crashes on file upload.', 1, 0, None),\n", " ('Thank you, issue resolved!', 1, 0, None),\n", " ('I want a refund.', 2, 0, None),\n", " ('I want a refund.', 2, 1, 'refund'),\n", " ('API returning 500 errors.', 3, 0, None),\n", " ('API returning 500 errors.', 3, 1, 'technical'),\n", "]\n", "for text, tid, step, cat in test_cases:\n", " out = sft_generate(text, tid, step, cat)\n", " print(f'T{tid} step{step}: [{text[:35]}] -> {out[:80]}')" ] } ] }