AlgoCore committed on
Commit
cf0d796
·
1 Parent(s): 42a3169

Add train_sft.ipynb: SFT pre-training with 1000 gold-label examples before GRPO

Browse files
Files changed (2) hide show
  1. train_grpo.ipynb +3 -0
  2. train_sft.ipynb +418 -0
train_grpo.ipynb CHANGED
@@ -74,6 +74,9 @@
74
  "\n",
75
  "ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
76
  "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
 
 
 
77
  "HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
78
  "\n",
79
  "RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
 
74
  "\n",
75
  "ENV_BASE_URL = 'https://algocore-support-ticket-env.hf.space'\n",
76
  "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
77
+ "# To use SFT pre-trained model instead (recommended - run train_sft.ipynb first):\n",
78
+ "# MODEL_NAME = '/kaggle/working/sft-model' # local SFT output\n",
79
+ "# MODEL_NAME = 'AlgoCore/support-ticket-sft-model' # HF Hub SFT model\n",
80
  "HF_REPO_ID = 'AlgoCore/support-ticket-grpo-model'\n",
81
  "\n",
82
  "RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
train_sft.ipynb ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "kernelspec": {"display_name": "Python 3", "name": "python3"},
6
+ "language_info": {"name": "python"},
7
+ "accelerator": "GPU"
8
+ },
9
+ "cells": [
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {},
13
+ "source": [
14
+ "# Support Ticket Env — SFT Pre-Training\n",
15
+ "**OpenEnv x Scalar Hackathon**\n",
16
+ "\n",
17
+ "Step 1 of 2-stage training: **Supervised Fine-Tuning** on gold-label examples.\n",
18
+ "This teaches the model correct JSON format + task logic before GRPO optimization.\n",
19
+ "\n",
20
+ "- **Model:** Qwen/Qwen2.5-0.5B-Instruct\n",
21
+ "- **Algorithm:** SFT via `trl.SFTTrainer`\n",
22
+ "- **Dataset:** 1000 gold-label (prompt, completion) pairs built from a 50-ticket bank, split 90/10 into train/validation\n",
23
+ "- **Runtime:** ~15-20 min on Kaggle T4\n",
24
+ "- **Output:** `/kaggle/working/sft-model` on Kaggle (`/content/sft-model` on Colab) → used as base for GRPO"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "!pip install -q 'trl>=0.18.2,<=0.24.0' 'transformers>=4.51.3,<=5.5.0' 'datasets>=3.4.1,<4.4.0' accelerate peft\n",
34
+ "!pip install -q requests matplotlib\n",
35
+ "print('Installation complete')"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "import os, json, random\n",
45
+ "\n",
46
+ "HF_TOKEN = ''\n",
47
+ "try:\n",
48
+ " from google.colab import userdata\n",
49
+ " HF_TOKEN = userdata.get('HF_TOKEN') or ''\n",
50
+ "except: pass\n",
51
+ "if not HF_TOKEN:\n",
52
+ " try:\n",
53
+ " from kaggle_secrets import UserSecretsClient\n",
54
+ " HF_TOKEN = UserSecretsClient().get_secret('HF_TOKEN') or ''\n",
55
+ " except: pass\n",
56
+ "if not HF_TOKEN:\n",
57
+ " HF_TOKEN = os.environ.get('HF_TOKEN', '')\n",
58
+ "if not HF_TOKEN:\n",
59
+ " raise ValueError('HF_TOKEN not found.')\n",
60
+ "\n",
61
+ "print('HF_TOKEN loaded OK')\n",
62
+ "\n",
63
+ "MODEL_NAME = 'Qwen/Qwen2.5-0.5B-Instruct'\n",
64
+ "RUNTIME = 'kaggle' if os.path.exists('/kaggle/working') else 'colab'\n",
65
+ "SFT_OUT = '/kaggle/working/sft-model' if RUNTIME == 'kaggle' else '/content/sft-model'\n",
66
+ "RESULTS_IMG = '/kaggle/working/sft_results.png' if RUNTIME == 'kaggle' else '/content/sft_results.png'\n",
67
+ "HF_REPO_ID = 'AlgoCore/support-ticket-sft-model'\n",
68
+ "\n",
69
+ "os.environ['HF_TOKEN'] = HF_TOKEN\n",
70
+ "os.environ['HUGGING_FACE_HUB_TOKEN'] = HF_TOKEN\n",
71
+ "\n",
72
+ "import torch\n",
73
+ "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
74
+ "print('GPU:', torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NO GPU')\n",
75
+ "print('VRAM:', round(torch.cuda.get_device_properties(0).total_memory/1e9,1), 'GB')\n",
76
+ "print('Runtime:', RUNTIME, '| SFT output:', SFT_OUT)"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "# ── Gold-label ticket bank (50 tickets, all categories) ──────────────────\n",
86
+ "ALL_TICKETS = [\n",
87
+ " # billing (12)\n",
88
+ " {'id':'B001','text':'I was charged twice for my subscription this month.','category':'billing','correct_action':'reply'},\n",
89
+ " {'id':'B002','text':'My invoice shows a charge for a plan I never subscribed to.','category':'billing','correct_action':'escalate'},\n",
90
+ " {'id':'B003','text':'I need an itemised invoice for my company accounting department.','category':'billing','correct_action':'reply'},\n",
91
+ " {'id':'B004','text':'Why was I charged before my trial period ended?','category':'billing','correct_action':'reply'},\n",
92
+ " {'id':'B005','text':'I switched plans but was still billed at the old rate.','category':'billing','correct_action':'reply'},\n",
93
+ " {'id':'B006','text':'My payment method was charged three times in one day.','category':'billing','correct_action':'escalate'},\n",
94
+ " {'id':'B007','text':'I cancelled my plan but the charge still appeared this month.','category':'billing','correct_action':'reply'},\n",
95
+ " {'id':'B008','text':'Can you send me a receipt for my last payment?','category':'billing','correct_action':'reply'},\n",
96
+ " {'id':'B009','text':'I was charged in USD but I signed up for GBP billing.','category':'billing','correct_action':'reply'},\n",
97
+ " {'id':'B010','text':'The discount code I applied is not reflected in my invoice.','category':'billing','correct_action':'reply'},\n",
98
+ " {'id':'B011','text':'I need to update my billing address on the invoice.','category':'billing','correct_action':'reply'},\n",
99
+ " {'id':'B012','text':'My credit card was charged even though payment failed notification was sent.','category':'billing','correct_action':'escalate'},\n",
100
+ " # account (10)\n",
101
+ " {'id':'A001','text':'I cannot log into my account. Password reset email never arrives.','category':'account','correct_action':'reply'},\n",
102
+ " {'id':'A002','text':'How do I cancel my subscription? I cannot find the option.','category':'account','correct_action':'reply'},\n",
103
+ " {'id':'A003','text':'I want to change my email address associated with the account.','category':'account','correct_action':'reply'},\n",
104
+ " {'id':'A004','text':'My account was locked after too many failed login attempts.','category':'account','correct_action':'reply'},\n",
105
+ " {'id':'A005','text':'I accidentally deleted my account. Can it be restored?','category':'account','correct_action':'reply'},\n",
106
+ " {'id':'A006','text':'I need to transfer my account to a different email.','category':'account','correct_action':'reply'},\n",
107
+ " {'id':'A007','text':'Two-factor authentication is not working for my account.','category':'account','correct_action':'reply'},\n",
108
+ " {'id':'A008','text':'I cannot find where to download my data for GDPR purposes.','category':'account','correct_action':'reply'},\n",
109
+ " {'id':'A009','text':'My username was changed without my permission.','category':'account','correct_action':'escalate'},\n",
110
+ " {'id':'A010','text':'I want to upgrade my account from free to premium.','category':'account','correct_action':'reply'},\n",
111
+ " # technical (10)\n",
112
+ " {'id':'T001','text':'Your app crashes every time I upload a file larger than 10 MB.','category':'technical','correct_action':'escalate'},\n",
113
+ " {'id':'T002','text':'The API is returning 500 errors intermittently for 2 hours.','category':'technical','correct_action':'escalate'},\n",
114
+ " {'id':'T003','text':'The dashboard is completely blank after the latest update.','category':'technical','correct_action':'escalate'},\n",
115
+ " {'id':'T004','text':'Export to CSV is broken — it downloads an empty file.','category':'technical','correct_action':'escalate'},\n",
116
+ " {'id':'T005','text':'Notifications are not being delivered to my email or phone.','category':'technical','correct_action':'escalate'},\n",
117
+ " {'id':'T006','text':'The mobile app freezes on the login screen on iOS 17.','category':'technical','correct_action':'escalate'},\n",
118
+ " {'id':'T007','text':'Search functionality returns no results for any query.','category':'technical','correct_action':'escalate'},\n",
119
+ " {'id':'T008','text':'Data sync between devices stopped working 3 days ago.','category':'technical','correct_action':'escalate'},\n",
120
+ " {'id':'T009','text':'The webhook integration keeps timing out and losing events.','category':'technical','correct_action':'escalate'},\n",
121
+ " {'id':'T010','text':'Browser extension throws a JavaScript error on every page load.','category':'technical','correct_action':'escalate'},\n",
122
+ " # refund (8)\n",
123
+ " {'id':'R001','text':'I want a full refund. I have not used the service at all.','category':'refund','correct_action':'reply'},\n",
124
+ " {'id':'R002','text':'I was double charged and need a refund for the extra payment.','category':'refund','correct_action':'reply'},\n",
125
+ " {'id':'R003','text':'The product did not work as advertised. I want my money back.','category':'refund','correct_action':'reply'},\n",
126
+ " {'id':'R004','text':'I cancelled within the 30-day window but have not received my refund.','category':'refund','correct_action':'reply'},\n",
127
+ " {'id':'R005','text':'I would like a partial refund for the unused months of my annual plan.','category':'refund','correct_action':'reply'},\n",
128
+ " {'id':'R006','text':'A refund was promised by your support agent 2 weeks ago but never arrived.','category':'refund','correct_action':'escalate'},\n",
129
+ " {'id':'R007','text':'I need a refund processed urgently as it was a fraudulent charge.','category':'refund','correct_action':'escalate'},\n",
130
+ " {'id':'R008','text':'How long does a refund take to appear on my credit card?','category':'refund','correct_action':'reply'},\n",
131
+ " # general (10)\n",
132
+ " {'id':'G001','text':'What are your business hours and do you have a phone number?','category':'general','correct_action':'reply'},\n",
133
+ " {'id':'G002','text':'Thank you! The issue has been resolved. You guys are awesome.','category':'general','correct_action':'close'},\n",
134
+ " {'id':'G003','text':'Do you offer a student discount or non-profit pricing?','category':'general','correct_action':'reply'},\n",
135
+ " {'id':'G004','text':'Where can I find your terms of service and privacy policy?','category':'general','correct_action':'reply'},\n",
136
+ " {'id':'G005','text':'Is your service available in my country? I am based in Brazil.','category':'general','correct_action':'reply'},\n",
137
+ " {'id':'G006','text':'Can I use your product for commercial purposes?','category':'general','correct_action':'reply'},\n",
138
+ " {'id':'G007','text':'Problem resolved, thanks for the quick response!','category':'general','correct_action':'close'},\n",
139
+ " {'id':'G008','text':'Do you have an affiliate or referral program?','category':'general','correct_action':'reply'},\n",
140
+ " {'id':'G009','text':'What integrations do you support with third-party tools?','category':'general','correct_action':'reply'},\n",
141
+ " {'id':'G010','text':'I just wanted to say your product has been amazing for our team.','category':'general','correct_action':'close'},\n",
142
+ "]\n",
143
+ "\n",
144
+ "# Gold reply templates keyed by category, plus a 'close' template; every escalation reuses the 'technical' reply\n",
145
+ "GOLD_REPLIES = {\n",
146
+ " 'billing': 'Thank you for reaching out about this billing issue. I can see the charge on your account. Our billing team will review your invoice and process any corrections. You will receive an updated receipt via email within 24 hours.',\n",
147
+ " 'account': 'Thank you for contacting us about your account. I understand how frustrating this can be. Please try resetting your password using the link sent to your registered email. If the issue persists, our account team will assist you directly.',\n",
148
+ " 'technical': 'I am escalating this technical issue to our engineering team immediately. They will investigate the root cause and provide a fix. You will receive a follow-up within 2 business hours.',\n",
149
+ " 'refund': 'I understand your refund request. Our refund policy allows for full refunds within 30 days. I am processing your refund now and it will appear on your credit card within 5-7 business days.',\n",
150
+ " 'general': 'Thank you for reaching out! Our support team is available Monday to Friday, 9am-6pm. You can also contact us via phone or email. We are happy to help with any questions you have.',\n",
151
+ " 'close': 'Thank you for letting us know the issue has been resolved. We are glad we could help! Please do not hesitate to contact us if you need anything else.',\n",
152
+ "}\n",
153
+ "\n",
154
+ "SYSTEM_PROMPT = '''You are a customer support AI agent. Respond ONLY with a JSON object.\n",
155
+ "\n",
156
+ "VALID action_type values: classify, reply, escalate, close\n",
157
+ "VALID category values: billing, technical, account, general, refund\n",
158
+ "\n",
159
+ "For classify: {\"action_type\": \"classify\", \"category\": \"<category>\"}\n",
160
+ "For reply: {\"action_type\": \"reply\", \"reply_text\": \"<response>\"}\n",
161
+ "For escalate: {\"action_type\": \"escalate\", \"reply_text\": \"Escalating to engineering.\"}\n",
162
+ "For close: {\"action_type\": \"close\", \"reply_text\": \"Closing ticket.\"}\n",
163
+ "\n",
164
+ "RULES:\n",
165
+ "- task_id=1: ALWAYS output action_type=classify first\n",
166
+ "- task_id=2: step=0 -> classify, step=1 -> reply/escalate/close\n",
167
+ "- task_id=3: step=0 -> classify, step=1 -> reply/escalate/close\n",
168
+ "- technical/crash/error/bug tickets -> escalate\n",
169
+ "- thank you/resolved tickets -> close\n",
170
+ "- billing/account/refund/general -> reply\n",
171
+ "- DO NOT use action_type=respond or action_type=resolve — those are INVALID'''\n",
172
+ "\n",
173
+ "def make_prompt(ticket_text, task_id, current_category=None, feedback='New ticket.', step=0):\n",
174
+ " user_msg = json.dumps({\n",
175
+ " 'ticket': ticket_text,\n",
176
+ " 'task_id': task_id,\n",
177
+ " 'current_category': current_category,\n",
178
+ " 'feedback': feedback,\n",
179
+ " 'step': step,\n",
180
+ " })\n",
181
+ " return [\n",
182
+ " {'role': 'system', 'content': SYSTEM_PROMPT},\n",
183
+ " {'role': 'user', 'content': user_msg},\n",
184
+ " ]\n",
185
+ "\n",
186
+ "def gold_completion(ticket, task_id, step):\n",
187
+ " \"\"\"Generate the perfect gold-label completion for a given ticket + step.\"\"\"\n",
188
+ " cat = ticket['category']\n",
189
+ " action = ticket['correct_action']\n",
190
+ "\n",
191
+ " if task_id == 1:\n",
192
+ " # Always classify\n",
193
+ " return json.dumps({'action_type': 'classify', 'category': cat})\n",
194
+ "\n",
195
+ " elif task_id == 2:\n",
196
+ " if step == 0:\n",
197
+ " return json.dumps({'action_type': 'classify', 'category': cat})\n",
198
+ " else:\n",
199
+ " if action == 'escalate':\n",
200
+ " return json.dumps({'action_type': 'escalate', 'reply_text': GOLD_REPLIES['technical']})\n",
201
+ " elif action == 'close':\n",
202
+ " return json.dumps({'action_type': 'close', 'reply_text': GOLD_REPLIES['close']})\n",
203
+ " else:\n",
204
+ " return json.dumps({'action_type': 'reply', 'reply_text': GOLD_REPLIES[cat]})\n",
205
+ "\n",
206
+ " else: # task 3\n",
207
+ " if step == 0:\n",
208
+ " return json.dumps({'action_type': 'classify', 'category': cat})\n",
209
+ " else:\n",
210
+ " if action == 'escalate':\n",
211
+ " return json.dumps({'action_type': 'escalate', 'reply_text': GOLD_REPLIES['technical']})\n",
212
+ " elif action == 'close':\n",
213
+ " return json.dumps({'action_type': 'close', 'reply_text': GOLD_REPLIES['close']})\n",
214
+ " else:\n",
215
+ " return json.dumps({'action_type': 'reply', 'reply_text': GOLD_REPLIES[cat]})\n",
216
+ "\n",
217
+ "print(f'Ticket bank: {len(ALL_TICKETS)} tickets')\n",
218
+ "print('Sample gold completion (T001, task2, step0):', gold_completion(ALL_TICKETS[22], 2, 0))\n",
219
+ "print('Sample gold completion (T001, task2, step1):', gold_completion(ALL_TICKETS[22], 2, 1))"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "from datasets import Dataset\n",
229
+ "from transformers import AutoTokenizer\n",
230
+ "\n",
231
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)\n",
232
+ "tokenizer.pad_token = tokenizer.eos_token\n",
233
+ "tokenizer.padding_side = 'right' # SFT uses right-padding\n",
234
+ "\n",
235
+ "SEEDS = list(range(200))\n",
236
+ "TASK_IDS = [1, 2, 3]\n",
237
+ "\n",
238
+ "def build_sft_dataset():\n",
239
+ " rows = []\n",
240
+ " rng = random.Random(42)\n",
241
+ " for task_id in TASK_IDS:\n",
242
+ " for seed in SEEDS:\n",
243
+ " ticket = ALL_TICKETS[seed % len(ALL_TICKETS)]\n",
244
+ "\n",
245
+ " # Step 0: classify\n",
246
+ " messages_0 = make_prompt(ticket['text'], task_id, None, 'New ticket. Classify it first.', 0)\n",
247
+ " completion_0 = gold_completion(ticket, task_id, 0)\n",
248
+ " # SFTTrainer expects 'text' column with full conversation\n",
249
+ " full_0 = tokenizer.apply_chat_template(\n",
250
+ " messages_0 + [{'role': 'assistant', 'content': completion_0}],\n",
251
+ " tokenize=False\n",
252
+ " )\n",
253
+ " rows.append({'text': full_0, 'task_id': task_id, 'step': 0})\n",
254
+ "\n",
255
+ " # Step 1: resolve (tasks 2 & 3)\n",
256
+ " if task_id in (2, 3):\n",
257
+ " messages_1 = make_prompt(\n",
258
+ " ticket['text'], task_id,\n",
259
+ " ticket['category'],\n",
260
+ " f\"Category set to {ticket['category']}. Now resolve the ticket.\",\n",
261
+ " 1\n",
262
+ " )\n",
263
+ " completion_1 = gold_completion(ticket, task_id, 1)\n",
264
+ " full_1 = tokenizer.apply_chat_template(\n",
265
+ " messages_1 + [{'role': 'assistant', 'content': completion_1}],\n",
266
+ " tokenize=False\n",
267
+ " )\n",
268
+ " rows.append({'text': full_1, 'task_id': task_id, 'step': 1})\n",
269
+ "\n",
270
+ " rng.shuffle(rows)\n",
271
+ " # Split 90/10 train/val\n",
272
+ " split = int(len(rows) * 0.9)\n",
273
+ " return Dataset.from_list(rows[:split]), Dataset.from_list(rows[split:])\n",
274
+ "\n",
275
+ "train_ds, val_ds = build_sft_dataset()\n",
276
+ "print(f'SFT dataset — train: {len(train_ds)}, val: {len(val_ds)}')\n",
277
+ "print('Sample text (first 400 chars):')\n",
278
+ "print(train_ds[0]['text'][:400])"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": null,
284
+ "metadata": {},
285
+ "outputs": [],
286
+ "source": [
287
+ "import torch\n",
288
+ "from transformers import AutoModelForCausalLM\n",
289
+ "from peft import LoraConfig, TaskType\n",
290
+ "from trl import SFTConfig, SFTTrainer\n",
291
+ "\n",
292
+ "print(f'Loading {MODEL_NAME}...')\n",
293
+ "model = AutoModelForCausalLM.from_pretrained(\n",
294
+ " MODEL_NAME,\n",
295
+ " dtype=torch.float16,\n",
296
+ " device_map={'': 0},\n",
297
+ " token=HF_TOKEN,\n",
298
+ ")\n",
299
+ "model.config.use_cache = False\n",
300
+ "\n",
301
+ "peft_config = LoraConfig(\n",
302
+ " task_type=TaskType.CAUSAL_LM,\n",
303
+ " r=16,\n",
304
+ " lora_alpha=32,\n",
305
+ " target_modules=['q_proj','v_proj','k_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
306
+ " lora_dropout=0.05,\n",
307
+ " bias='none',\n",
308
+ ")\n",
309
+ "\n",
310
+ "sft_config = SFTConfig(\n",
311
+ " output_dir=SFT_OUT,\n",
312
+ " num_train_epochs=3,\n",
313
+ " per_device_train_batch_size=4,\n",
314
+ " per_device_eval_batch_size=4,\n",
315
+ " gradient_accumulation_steps=2,\n",
316
+ " learning_rate=2e-4,\n",
317
+ " lr_scheduler_type='cosine',\n",
318
+ " warmup_ratio=0.05,\n",
319
+ " weight_decay=0.01,\n",
320
+ " max_grad_norm=1.0,\n",
321
+ " fp16=True,\n",
322
+ " bf16=False,\n",
323
+ " logging_steps=10,\n",
324
+ " eval_strategy='steps',\n",
325
+ " eval_steps=50,\n",
326
+ " save_strategy='no',\n",
327
+ " report_to='none',\n",
328
+ " max_seq_length=512,\n",
329
+ " dataset_text_field='text',\n",
330
+ " dataloader_pin_memory=False,\n",
331
+ " ddp_find_unused_parameters=False,\n",
332
+ ")\n",
333
+ "\n",
334
+ "trainer = SFTTrainer(\n",
335
+ " model=model,\n",
336
+ " args=sft_config,\n",
337
+ " train_dataset=train_ds,\n",
338
+ " eval_dataset=val_ds,\n",
339
+ " peft_config=peft_config,\n",
340
+ " processing_class=tokenizer,\n",
341
+ ")\n",
342
+ "\n",
343
+ "print(f'SFTTrainer ready | train={len(train_ds)} val={len(val_ds)}')\n",
344
+ "print('Starting SFT...')\n",
345
+ "print('=' * 60)\n",
346
+ "\n",
347
+ "sft_result = trainer.train()\n",
348
+ "\n",
349
+ "print('=' * 60)\n",
350
+ "print(f'SFT complete! Loss: {sft_result.training_loss:.4f}')"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "metadata": {},
357
+ "outputs": [],
358
+ "source": [
359
+ "import os\n",
360
+ "os.makedirs(SFT_OUT, exist_ok=True)\n",
361
+ "trainer.save_model(SFT_OUT)\n",
362
+ "tokenizer.save_pretrained(SFT_OUT)\n",
363
+ "print(f'SFT model saved to {SFT_OUT}')\n",
364
+ "\n",
365
+ "# Push to HF Hub\n",
366
+ "try:\n",
367
+ " from huggingface_hub import HfApi\n",
368
+ " api = HfApi(token=HF_TOKEN)\n",
369
+ " api.create_repo(HF_REPO_ID, exist_ok=True, private=False)\n",
370
+ " api.upload_folder(folder_path=SFT_OUT, repo_id=HF_REPO_ID, repo_type='model')\n",
371
+ " print(f'SFT model pushed to: https://huggingface.co/{HF_REPO_ID}')\n",
372
+ "except Exception as e:\n",
373
+ " print(f'HF push failed: {e}')\n",
374
+ "\n",
375
+ "print('\\n✅ SFT done. Now run train_grpo.ipynb and set MODEL_NAME to:')\n",
376
+ "print(f' MODEL_NAME = \"{SFT_OUT}\" # or \"{HF_REPO_ID}\"')"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "# Quick eval to verify SFT worked\n",
386
+ "import re\n",
387
+ "\n",
388
+ "model.eval()\n",
389
+ "model.config.use_cache = True\n",
390
+ "\n",
391
+ "def sft_generate(ticket_text, task_id, step=0, current_category=None):\n",
392
+ " messages = make_prompt(ticket_text, task_id, current_category, 'New ticket.', step)\n",
393
+ " prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
394
+ " inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512).to(DEVICE)\n",
395
+ " with torch.no_grad():\n",
396
+ " out = model.generate(\n",
397
+ " **inputs, max_new_tokens=80, do_sample=False,\n",
398
+ " pad_token_id=tokenizer.eos_token_id, use_cache=True\n",
399
+ " )\n",
400
+ " return tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
401
+ "\n",
402
+ "print('=== SFT Quick Eval ===')\n",
403
+ "test_cases = [\n",
404
+ " ('I was charged twice this month.', 1, 0, None),\n",
405
+ " ('App crashes on file upload.', 1, 0, None),\n",
406
+ " ('Thank you, issue resolved!', 1, 0, None),\n",
407
+ " ('I want a refund.', 2, 0, None),\n",
408
+ " ('I want a refund.', 2, 1, 'refund'),\n",
409
+ " ('API returning 500 errors.', 3, 0, None),\n",
410
+ " ('API returning 500 errors.', 3, 1, 'technical'),\n",
411
+ "]\n",
412
+ "for text, tid, step, cat in test_cases:\n",
413
+ " out = sft_generate(text, tid, step, cat)\n",
414
+ " print(f'T{tid} step{step}: [{text[:35]}] -> {out[:80]}')"
415
+ ]
416
+ }
417
+ ]
418
+ }