Spaces:
Sleeping
Sleeping
Add train_sft.ipynb: SFT pre-training with 1000 gold-label examples before GRPO
Browse files- train_grpo.ipynb +3 -0
- train_sft.ipynb +418 -0
train_grpo.ipynb
CHANGED
|
@@ -74,6 +74,9 @@
|
|
| 74 |
"\n",
|
| 75 |
"ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
|
| 76 |
"MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
|
|
|
|
|
|
|
|
|
|
| 77 |
"HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
|
| 78 |
"\n",
|
| 79 |
"RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
|
|
|
|
| 74 |
"\n",
|
| 75 |
"ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
|
| 76 |
"MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
|
| 77 |
+
"# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n",
|
| 78 |
+
"# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n",
|
| 79 |
+
"# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\n",
|
| 80 |
"HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
|
| 81 |
"\n",
|
| 82 |
"RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
|
train_sft.ipynb
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"kernelspec": {"display_name": "Python 3", "name": "python3"},
|
| 6 |
+
"language_info": {"name": "python"},
|
| 7 |
+
"accelerator": "GPU"
|
| 8 |
+
},
|
| 9 |
+
"cells": [
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "markdown",
|
| 12 |
+
"metadata": {},
|
| 13 |
+
"source": [
|
| 14 |
+
"# Support Ticket Env — SFT Pre-Training\n",
|
| 15 |
+
"**OpenEnv x Scalar Hackathon**\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"Step 1 of 2-stage training: **Supervised Fine-Tuning** on gold-label examples.\n",
|
| 18 |
+
"This teaches the model correct JSON format + task logic before GRPO optimization.\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"- **Model:** Qwen/Qwen2.5-0.5B-Instruct\n",
|
| 21 |
+
"- **Algorithm:** SFT via `trl.SFTTrainer`\n",
|
| 22 |
+
"- **Dataset:** 1000 gold-label (prompt, completion) pairs\n",
|
| 23 |
+
"- **Runtime:** ~15-20 min on Kaggle T4\n",
|
| 24 |
+
"- **Output:** `/kaggle/working/sft-model` on Kaggle (`/content/sft-model` on Colab) → used as base for GRPO"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": null,
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n",
|
| 34 |
+
"!pip install -q requests matplotlib\n",
|
| 35 |
+
"print('Installation complete')"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
import os
import json
import random

# Resolve the Hugging Face token, trying each source in turn:
#   1. Colab secrets, 2. Kaggle secrets, 3. the HF_TOKEN environment variable.
# The first non-empty value wins; if none is found the notebook stops here.
HF_TOKEN = ''
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN') or ''
except Exception:
    # Deliberately best-effort: we are not on Colab, or the secret is not
    # available to this notebook. Catching `Exception` instead of using a bare
    # `except` lets KeyboardInterrupt / SystemExit propagate.
    pass
if not HF_TOKEN:
    try:
        from kaggle_secrets import UserSecretsClient
        HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''
    except Exception:
        # Same best-effort fallback for Kaggle secrets.
        pass
if not HF_TOKEN:
    HF_TOKEN = os.environ.get('HF_TOKEN', '')
if not HF_TOKEN:
    raise ValueError('HF_TOKEN not found.')

print('HF_TOKEN loaded OK')
|
| 62 |
+
"\n",
|
# ── Run configuration ────────────────────────────────────────────────────
MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'
# Kaggle is detected by its working directory; anything else is treated as Colab.
RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'
SFT_OUT = '/kaggle/working/sft-model' if RUNTIME == 'kaggle' else '/content/sft-model'
RESULTS_IMG = '/kaggle/working/sft_results.png' if RUNTIME == 'kaggle' else '/content/sft_results.png'
HF_REPO_ID = 'AlgoCore/support-ticket-sft-model'

# Export the token under both variable names for downstream Hub access.
os.environ['HF_TOKEN'] = HF_TOKEN
os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN

import torch
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
    print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9,1), 'GB')
else:
    # Querying device 0 without CUDA raises, so report the absence instead of
    # crashing the setup cell.
    print('GPU:', 'NO GPU')
    print('VRAM:', 'n/a')
print('Runtime:', RUNTIME, '| SFT output:', SFT_OUT)
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"execution_count": null,
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"outputs": [],
|
| 84 |
+
"source": [
|
| 85 |
+
"# ── Gold-label ticket bank (50 tickets, all categories) ──────────────────\n",
|
| 86 |
+
"ALL_TICKETS = [\n",
|
| 87 |
+
" # billing (12)\n",
|
| 88 |
+
" {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
|
| 89 |
+
" {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
|
| 90 |
+
" {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
|
| 91 |
+
" {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n",
|
| 92 |
+
" {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n",
|
| 93 |
+
" {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n",
|
| 94 |
+
" {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n",
|
| 95 |
+
" {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n",
|
| 96 |
+
" {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n",
|
| 97 |
+
" {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n",
|
| 98 |
+
" {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n",
|
| 99 |
+
" {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n",
|
| 100 |
+
" # account (10)\n",
|
| 101 |
+
" {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
|
| 102 |
+
" {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
|
| 103 |
+
" {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n",
|
| 104 |
+
" {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n",
|
| 105 |
+
" {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n",
|
| 106 |
+
" {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n",
|
| 107 |
+
" {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n",
|
| 108 |
+
" {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n",
|
| 109 |
+
" {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n",
|
| 110 |
+
" {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n",
|
| 111 |
+
" # technical (10)\n",
|
| 112 |
+
" {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
|
| 113 |
+
" {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
|
| 114 |
+
" {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n",
|
| 115 |
+
" {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n",
|
| 116 |
+
" {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n",
|
| 117 |
+
" {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n",
|
| 118 |
+
" {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n",
|
| 119 |
+
" {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n",
|
| 120 |
+
" {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n",
|
| 121 |
+
" {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n",
|
| 122 |
+
" # refund (8)\n",
|
| 123 |
+
" {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
|
| 124 |
+
" {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n",
|
| 125 |
+
" {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n",
|
| 126 |
+
" {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n",
|
| 127 |
+
" {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n",
|
| 128 |
+
" {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n",
|
| 129 |
+
" {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n",
|
| 130 |
+
" {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n",
|
| 131 |
+
" # general (10)\n",
|
| 132 |
+
" {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
|
| 133 |
+
" {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
|
| 134 |
+
" {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n",
|
| 135 |
+
" {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n",
|
| 136 |
+
" {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n",
|
| 137 |
+
" {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n",
|
| 138 |
+
" {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n",
|
| 139 |
+
" {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n",
|
| 140 |
+
" {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n",
|
| 141 |
+
" {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n",
|
| 142 |
+
"]\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"# Gold reply templates per category\n",
|
| 145 |
+
"GOLD_REPLIES = {\n",
|
| 146 |
+
" 'billing': 'Thank you for reaching out about this billing issue. I can see the charge on your account. Our billing team will review your invoice and process any corrections. You will receive an updated receipt via email within 24 hours.',\n",
|
| 147 |
+
" 'account': 'Thank you for contacting us about your account. I understand how frustrating this can be. Please try resetting your password using the link sent to your registered email. If the issue persists, our account team will assist you directly.',\n",
|
| 148 |
+
" 'technical': 'I am escalating this technical issue to our engineering team immediately. They will investigate the root cause and provide a fix. You will receive a follow-up within 2 business hours.',\n",
|
| 149 |
+
" 'refund': 'I understand your refund request. Our refund policy allows for full refunds within 30 days. I am processing your refund now and it will appear on your credit card within 5-7 business days.',\n",
|
| 150 |
+
" 'general': 'Thank you for reaching out! Our support team is available Monday to Friday, 9am-6pm. You can also contact us via phone or email. We are happy to help with any questions you have.',\n",
|
| 151 |
+
" 'close': 'Thank you for letting us know the issue has been resolved. We are glad we could help! Please do not hesitate to contact us if you need anything else.',\n",
|
| 152 |
+
"}\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"VALID action_type values: classify, reply, escalate, close\n",
|
| 157 |
+
"VALID category values: billing, technical, account, general, refund\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"For classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\n",
|
| 160 |
+
"For reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\n",
|
| 161 |
+
"For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n",
|
| 162 |
+
"For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"RULES:\n",
|
| 165 |
+
"- task_id=1: ALWAYS output action_type=classify first\n",
|
| 166 |
+
"- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n",
|
| 167 |
+
"- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n",
|
| 168 |
+
"- technical/crash/error/bug tickets -> escalate\n",
|
| 169 |
+
"- thank you/resolved tickets -> close\n",
|
| 170 |
+
"- billing/account/refund/general -> reply\n",
|
| 171 |
+
"- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n",
|
| 172 |
+
"\n",
|
def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):
    """Build the chat prompt (system turn + user turn) for one step.

    The user turn is the JSON encoding of an observation dict with the keys
    ``ticket``, ``task_id``, ``current_category``, ``feedback`` and ``step``,
    in that order. The system turn carries ``SYSTEM_PROMPT``.

    Returns:
        A list of two ``{'role': ..., 'content': ...}`` message dicts.
    """
    observation = {
        'ticket': ticket_text,
        'task_id': task_id,
        'current_category': current_category,
        'feedback': feedback,
        'step': step,
    }
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_turn = {'role': 'user', 'content': json.dumps(observation)}
    return [system_turn, user_turn]
|
| 185 |
+
"\n",
|
def gold_completion(ticket, task_id, step):
    """Return the gold-label completion (a JSON string) for a ticket and step.

    Args:
        ticket: dict with at least ``'category'`` and ``'correct_action'``.
        task_id: task number; task 1 is classification-only.
        step: 0 for the classify step, anything else for the resolve step.

    Returns:
        ``{"action_type": "classify", "category": ...}`` for task 1 (any step)
        and for step 0 of every other task; otherwise an
        ``escalate`` / ``close`` / ``reply`` action with a templated
        ``reply_text`` taken from ``GOLD_REPLIES``.

    Raises:
        KeyError: if the ticket lacks a required key, or a reply is needed for
            a category that has no entry in ``GOLD_REPLIES``.
    """
    cat = ticket['category']
    action = ticket['correct_action']

    # Task 1 always classifies; tasks 2 and 3 classify on their first step.
    if task_id == 1 or step == 0:
        return json.dumps({'action_type': 'classify', 'category': cat})

    # Resolve step. Tasks 2 and 3 share exactly the same gold behaviour, so a
    # single path replaces the two previously duplicated branches.
    if action == 'escalate':
        # NOTE(review): every escalation uses the 'technical' template, even for
        # billing/account/refund tickets — confirm this wording is intended.
        reply_text = GOLD_REPLIES['technical']
    elif action == 'close':
        reply_text = GOLD_REPLIES['close']
    else:
        # Any other action value falls back to a category-specific reply.
        action = 'reply'
        reply_text = GOLD_REPLIES[cat]
    return json.dumps({'action_type': action, 'reply_text': reply_text})
|
| 216 |
+
"\n",
|
| 217 |
+
"print(f'Ticket bank: {len(ALL_TICKETS)} tickets')\n",
|
| 218 |
+
"print('Sample gold completion (T001, task2, step0):', gold_completion(ALL_TICKETS[22], 2, 0))\n",
|
| 219 |
+
"print('Sample gold completion (T001, task2, step1):', gold_completion(ALL_TICKETS[22], 2, 1))"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "code",
|
| 224 |
+
"execution_count": null,
|
| 225 |
+
"metadata": {},
|
| 226 |
+
"outputs": [],
|
| 227 |
+
"source": [
|
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
# Keep the tokenizer's own pad token when it defines one; only fall back to EOS
# when none exists. Aliasing pad to EOS unconditionally can make collators that
# mask padding by token id also mask the real end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'  # SFT uses right-padding

# 200 seeds are mapped onto the ticket bank by modulo in build_sft_dataset.
SEEDS = list(range(200))
TASK_IDS = [1, 2, 3]
|
| 237 |
+
"\n",
|
def build_sft_dataset(val_fraction=0.1):
    """Build the SFT train/validation datasets from the gold-label ticket bank.

    For every task id and every seed, the ticket ``ALL_TICKETS[seed % len]`` is
    rendered into a step-0 (classify) example and, for tasks 2 and 3, a step-1
    (resolve) example. Each example is the full chat-templated conversation
    stored in a ``'text'`` column, alongside ``'task_id'`` and ``'step'``.

    The split is done per ticket, not per row: because the seeds cycle over the
    ticket bank, every example occurs several times, and a row-level split after
    shuffling would place copies of training rows in the validation set, making
    the validation loss meaningless. Held-out tickets never appear in training.

    Args:
        val_fraction: fraction of tickets held out for validation
            (default 0.1).

    Returns:
        ``(train_dataset, val_dataset)`` as ``datasets.Dataset`` objects.

    Raises:
        ValueError: if ``ALL_TICKETS`` is empty.
    """
    if not ALL_TICKETS:
        raise ValueError('ALL_TICKETS is empty; cannot build the SFT dataset.')

    rng = random.Random(42)  # fixed seed keeps the split and order reproducible

    # Choose the held-out tickets (assumes ticket ids are unique).
    ticket_ids = [t['id'] for t in ALL_TICKETS]
    rng.shuffle(ticket_ids)
    n_val = int(len(ticket_ids) * val_fraction)
    val_ids = set(ticket_ids[:n_val])

    def render(messages, completion):
        # SFTTrainer reads the 'text' column holding the full conversation.
        return tokenizer.apply_chat_template(
            messages + [{'role': 'assistant', 'content': completion}],
            tokenize=False,
        )

    train_rows, val_rows = [], []
    for task_id in TASK_IDS:
        for seed in SEEDS:
            ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]
            rows = val_rows if ticket['id'] in val_ids else train_rows

            # Step 0: classify
            messages_0 = make_prompt(ticket['text'], task_id, None, 'New ticket. Classify it first.', 0)
            completion_0 = gold_completion(ticket, task_id, 0)
            rows.append({'text': render(messages_0, completion_0), 'task_id': task_id, 'step': 0})

            # Step 1: resolve (tasks 2 & 3)
            if task_id in (2, 3):
                messages_1 = make_prompt(
                    ticket['text'], task_id,
                    ticket['category'],
                    f"Category set to {ticket['category']}. Now resolve the ticket.",
                    1,
                )
                completion_1 = gold_completion(ticket, task_id, 1)
                rows.append({'text': render(messages_1, completion_1), 'task_id': task_id, 'step': 1})

    rng.shuffle(train_rows)
    rng.shuffle(val_rows)
    return Dataset.from_list(train_rows), Dataset.from_list(val_rows)
|
| 274 |
+
"\n",
|
| 275 |
+
"train_ds, val_ds = build_sft_dataset()\n",
|
| 276 |
+
"print(f'SFT dataset — train: {len(train_ds)}, val: {len(val_ds)}')\n",
|
| 277 |
+
"print('Sample text (first 400 chars):')\n",
|
| 278 |
+
"print(train_ds[0]['text'][:400])"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"cell_type": "code",
|
| 283 |
+
"execution_count": null,
|
| 284 |
+
"metadata": {},
|
| 285 |
+
"outputs": [],
|
| 286 |
+
"source": [
|
| 287 |
+
"import torch\n",
|
| 288 |
+
"from transformers import AutoModelForCausalLM\n",
|
| 289 |
+
"from peft import LoraConfig, TaskType\n",
|
| 290 |
+
"from trl import SFTConfig, SFTTrainer\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"print(f'Loading {MODEL_NAME}...')\n",
|
| 293 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 294 |
+
" MODEL_NAME,\n",
|
| 295 |
+
" dtype=torch.float16,\n",
|
| 296 |
+
" device_map={'': 0},\n",
|
| 297 |
+
" token=HF_TOKEN,\n",
|
| 298 |
+
")\n",
|
| 299 |
+
"model.config.use_cache = False\n",
|
| 300 |
+
"\n",
|
| 301 |
+
"peft_config = LoraConfig(\n",
|
| 302 |
+
" task_type=TaskType.CAUSAL_LM,\n",
|
| 303 |
+
" r=16,\n",
|
| 304 |
+
" lora_alpha=32,\n",
|
| 305 |
+
" target_modules=['q_proj','v_proj','k_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
|
| 306 |
+
" lora_dropout=0.05,\n",
|
| 307 |
+
" bias='none',\n",
|
| 308 |
+
")\n",
|
| 309 |
+
"\n",
|
# Training hyper-parameters for the SFT stage.
sft_config = SFTConfig(
    output_dir=SFT_OUT,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective batch size 4 * 2 = 8 per device
    learning_rate=2e-4,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    weight_decay=0.01,
    max_grad_norm=1.0,
    fp16=True,
    bf16=False,
    logging_steps=10,
    eval_strategy='steps',
    eval_steps=50,
    save_strategy='no',  # the model is saved explicitly after training
    report_to='none',
    # `max_seq_length` was renamed to `max_length` in TRL's SFTConfig and the
    # old name is rejected by newer releases allowed by this notebook's pin.
    max_length=512,
    dataset_text_field='text',
    dataloader_pin_memory=False,
    ddp_find_unused_parameters=False,
)
|
| 333 |
+
"\n",
|
| 334 |
+
"trainer = SFTTrainer(\n",
|
| 335 |
+
" model=model,\n",
|
| 336 |
+
" args=sft_config,\n",
|
| 337 |
+
" train_dataset=train_ds,\n",
|
| 338 |
+
" eval_dataset=val_ds,\n",
|
| 339 |
+
" peft_config=peft_config,\n",
|
| 340 |
+
" processing_class=tokenizer,\n",
|
| 341 |
+
")\n",
|
| 342 |
+
"\n",
|
| 343 |
+
"print(f'SFTTrainer ready | train={len(train_ds)} val={len(val_ds)}')\n",
|
| 344 |
+
"print('Starting SFT...')\n",
|
| 345 |
+
"print('=' * 60)\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"sft_result = trainer.train()\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"print('=' * 60)\n",
|
| 350 |
+
"print(f'SFT complete! Loss: {sft_result.training_loss:.4f}')"
|
| 351 |
+
]
|
| 352 |
+
},
|
| 353 |
+
{
|
| 354 |
+
"cell_type": "code",
|
| 355 |
+
"execution_count": null,
|
| 356 |
+
"metadata": {},
|
| 357 |
+
"outputs": [],
|
| 358 |
+
"source": [
|
| 359 |
+
"import os\n",
|
| 360 |
+
"os.makedirs(SFT_OUT, exist_ok=True)\n",
|
| 361 |
+
"trainer.save_model(SFT_OUT)\n",
|
| 362 |
+
"tokenizer.save_pretrained(SFT_OUT)\n",
|
| 363 |
+
"print(f'SFT model saved to {SFT_OUT}')\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"# Push to HF Hub\n",
|
| 366 |
+
"try:\n",
|
| 367 |
+
" from huggingface_hub import HfApi\n",
|
| 368 |
+
" api = HfApi(token=HF_TOKEN)\n",
|
| 369 |
+
" api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
|
| 370 |
+
" api.upload_folder(folder_path=SFT_OUT, repo_id=HF_REPO_ID, repo_type='model')\n",
|
| 371 |
+
" print(f'SFT model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
|
| 372 |
+
"except Exception as e:\n",
|
| 373 |
+
" print(f'HF push failed: {e}')\n",
|
| 374 |
+
"\n",
|
| 375 |
+
"print('\\n✅ SFT done. Now run train_grpo.ipynb and set MODEL_NAME to:')\n",
|
| 376 |
+
"print(f' MODEL_NAME = \"{SFT_OUT}\" # or \"{HF_REPO_ID}\"')"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": null,
|
| 382 |
+
"metadata": {},
|
| 383 |
+
"outputs": [],
|
| 384 |
+
"source": [
|
| 385 |
+
"# Quick eval to verify SFT worked\n",
|
| 386 |
+
"import re\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"model.eval()\n",
|
| 389 |
+
"model.config.use_cache = True\n",
|
| 390 |
+
"\n",
|
def sft_generate(ticket_text, task_id, step=0, current_category=None,
                 feedback=None, max_new_tokens=80):
    """Greedy-decode the model's action for one ticket and step.

    Args:
        ticket_text: raw ticket text.
        task_id: task number placed in the prompt.
        step: 0 for the classify step, otherwise the resolve step.
        current_category: category already assigned (used at the resolve step).
        feedback: feedback string placed in the prompt. When omitted it is
            derived from ``step`` so that it matches the strings used to build
            the training data; a fixed ``'New ticket.'`` would evaluate the
            model on prompts it was never trained on.
        max_new_tokens: generation budget (default 80).

    Returns:
        The decoded completion, without the prompt or special tokens.
    """
    if feedback is None:
        if step == 0:
            feedback = 'New ticket. Classify it first.'
        else:
            feedback = f'Category set to {current_category}. Now resolve the ticket.'

    messages = make_prompt(ticket_text, task_id, current_category, feedback, step)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512).to(DEVICE)
    with torch.no_grad():
        out = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False,
            pad_token_id=tokenizer.eos_token_id, use_cache=True
        )
    # Drop the prompt tokens so only the newly generated text is returned.
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
|
| 401 |
+
"\n",
|
| 402 |
+
"print('=== SFT Quick Eval ===')\n",
|
| 403 |
+
"test_cases = [\n",
|
| 404 |
+
" ('I was charged twice this month.', 1, 0, None),\n",
|
| 405 |
+
" ('App crashes on file upload.', 1, 0, None),\n",
|
| 406 |
+
" ('Thank you, issue resolved!', 1, 0, None),\n",
|
| 407 |
+
" ('I want a refund.', 2, 0, None),\n",
|
| 408 |
+
" ('I want a refund.', 2, 1, 'refund'),\n",
|
| 409 |
+
" ('API returning 500 errors.', 3, 0, None),\n",
|
| 410 |
+
" ('API returning 500 errors.', 3, 1, 'technical'),\n",
|
| 411 |
+
"]\n",
|
| 412 |
+
"for text, tid, step, cat in test_cases:\n",
|
| 413 |
+
" out = sft_generate(text, tid, step, cat)\n",
|
| 414 |
+
" print(f'T{tid} step{step}: [{text[:35]}] -> {out[:80]}')"
|
| 415 |
+
]
|
| 416 |
+
}
|
| 417 |
+
]
|
| 418 |
+
}
|