Phaedrus33 committed on
Commit
1948ad5
·
verified ·
1 Parent(s): 4cc6626

Upload train_grpo_final.py with huggingface_hub

Files changed (1)
  1. train_grpo_final.py +1131 -0
train_grpo_final.py ADDED
@@ -0,0 +1,1131 @@
"""
Train Qwen3-32B with GRPO on reasoning traces.

Exactly 3 reward functions:
1. boxed_reward: response contains \\boxed{...}
2. think_tags_reward: response contains <think>...</think>
3. accuracy_reward: extracted answer matches ground truth

Uses Unsloth for efficient training with vLLM fast inference.
Loads reasoning traces (2400 Type A from train.csv) with question-type-aware prompts.

Usage:
    # Basic (SFT LoRA from HF)
    python train_grpo_final.py --sft-model USERNAME/sft-lora

    # With more steps
    python train_grpo_final.py --sft-model USERNAME/sft-lora --max-steps 200

    # Push merged model to HF after training
    python train_grpo_final.py \\
        --sft-model USERNAME/sft-lora \\
        --push-to-hub --merge-16bit \\
        --hf-repo USERNAME/grpo-final \\
        --hf-token hf_xxx

    # Dry run (validate data)
    python train_grpo_final.py --sft-model ./path/to/sft --dry-run
"""

# Enable Unsloth's memory-efficient standby mode for vLLM
# Must be set BEFORE importing unsloth
import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
import json
import re
import argparse
import logging
from typing import Dict, List
from collections import Counter

from tqdm import tqdm
from datasets import Dataset

# Metric computation (same as SFT training)
from telco_utils import parse_type_a_question
from generate_traces_final import (
    compute_all_metrics, format_metrics_block,
    compute_type_b_metrics, format_type_b_metrics_block,
)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# =============================================================================
# PatchFastRL - Required for Unsloth + GRPOTrainer integration
# =============================================================================
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)


# =============================================================================
# V19 SYSTEM PROMPTS
# =============================================================================

TELCO_SYSTEM_PROMPT = """You are a 5G network root cause classifier. You receive pre-computed metrics and a multiple-choice question. Walk through the decision rules below IN ORDER, show your work for each check, identify the root cause, then match it to the correct option label.

OUTPUT FORMAT (mandatory):
1. Wrap ALL reasoning inside <think>...</think> tags.
2. After </think>, output EXACTLY \\boxed{LABEL} where LABEL is the option (e.g. C3, 7, M5).
3. Do NOT write anything after \\boxed{LABEL}. No explanation, no period, no newline text.
4. Every response MUST end with \\boxed{LABEL}. Omitting it is a failure.

DECISION RULES (apply first matching rule):

TIER 1 - check in order, return first match:
1. max_speed > 40 -> "speed exceeds 40km/h"
2. max_distance_low_tp > 1.0 -> "coverage distance exceeds 1km" (overshooting)
3. handover_count >= 3 -> "frequent handovers"
4. avg_rb < 170 -> "average scheduled RBs below 160"

TIER 2 - C1 detection (if ANY sub-rule matches -> "downtilt too large"):
5a. min_rsrp < -90 AND pci_collision = no AND c4_interference < 3
5b. strong_neighbor_count < 0.5 AND serving_tilt >= 15
5c. pci_collision = yes AND strong_neighbor_count < 0.5
OVERRIDES on C1:
- post_ho_good_streak >= 2 -> "neighboring cell higher throughput" instead
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision" instead
- avg_rsrp > -79 AND strong_neighbor_count > 1.0 -> "neighboring cell higher throughput" instead

TIER 3 - interference:
6. c4_interference >= 3 -> "overlapping coverage/interference"
SKIP if: (min_neighbor_diff / c4_interference) < -0.5 AND c4_interference < 12

TIER 4 - PCI collision (pci_collision = yes):
7. If pci_collision_ratio >= 1.0 -> "PCI mod 30 collision"
If pci_collision_ratio < 1.0:
- serving_tilt > 10 AND rsrp_trend > 0.4 -> "downtilt too large"
- else -> "neighboring cell higher throughput"
If avg_off_axis > 30:
- min_rsrp < -90 -> "downtilt too large" (with override checks from Tier 2)
- else -> "neighboring cell higher throughput" (with override checks from Tier 5)

TIER 5 - C1 vs C3 tiebreaker:
8. serving_tilt >= 28 AND avg_sinr >= 12 -> "neighboring cell higher throughput" (SINR gate)
serving_tilt >= 28 AND avg_sinr < 12 -> "downtilt too large"
serving_tilt < 12 -> "neighboring cell higher throughput"
avg_rsrp < -90 -> "downtilt too large"
avg_rsrp > -82 -> "neighboring cell higher throughput"
Low confidence (avg_rsrp -90 to -82) -> rescue rules:
R1: pci_collision_ratio >= 0.9 -> "PCI mod 30 collision"
R2: strong_neighbor_count < 0.8 -> "downtilt too large"
R3: c4_interference >= 3.0 -> "downtilt too large"
R4: default -> "neighboring cell higher throughput"
OVERRIDES if "neighboring cell" (high/medium confidence):
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision"
- rsrp_change > 5 AND rsrp_trend > 0.5 AND nb_within_5db < 1.0 -> "downtilt too large"
- rsrp_recovery > 15 -> "downtilt too large"
- serving_tilt > 6 AND nb_within_5db < 1.0 -> "downtilt too large"
OVERRIDES if "downtilt" (high/medium confidence):
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision"
- avg_rsrp > -79 AND strong_neighbor_count > 1.0 -> "neighboring cell higher throughput"

Show your reasoning inside <think> tags, checking each tier in order. Then match the identified root cause to the option that describes it and answer with EXACTLY \\boxed{LABEL}. You MUST always end your response with \\boxed{LABEL}.
Examples: \\boxed{C3}, \\boxed{7}, \\boxed{M5}"""


TYPE_B_SYSTEM_PROMPT = """You are a 5G drive test root cause analyzer. You receive pre-computed metrics and a multiple-choice question about throughput drops. Walk through the decision rules below IN ORDER, show your work for each check, identify the root cause, then match it to the correct option label.

OUTPUT FORMAT (mandatory):
1. Wrap ALL reasoning inside <think>...</think> tags.
2. After </think>, output EXACTLY \\boxed{LABEL} where LABEL is the option letter (e.g. A, D, G).
3. Do NOT write anything after \\boxed{LABEL}. No explanation, no period, no newline text.
4. Every response MUST end with \\boxed{LABEL}. Omitting it is a failure.

IMPORTANT: Options are SHUFFLED per question - identify the root cause FIRST, then find which option letter matches it.

DECISION RULES (apply first matching rule):

1. avg_cce_fail > 0.25 -> "PDCCH congestion" (I)
2. actual_handovers >= 3 -> "intra-freq threshold too low / ping-pong" (H)
3. ratio_a3_ho >= 3 AND a3_events >= 2, OR rrc_reestablish > 0 AND a3_events >= 1:
-> Check n1_in_config:
If n1_in_config = False -> "missing neighbor cell configuration" (E)
If n1_in_config = True -> "intra-freq threshold too high" (G)
4. rsrp_var_norm > 0.08 AND avg_rsrp > -95 -> "overlap coverage" (A)
5. avg_rsrp < -95 -> "weak coverage" (F)

PHY HEALTH ANALYSIS (if no rule above matches):
6. If phy_healthy_during_low_tp = True AND neighbors_within_3dB = 0 AND avg_sinr > 10:
-> "transport/server-side anomaly" (D)
Meaning: Radio link healthy during TP drops, bottleneck above PHY layer.
7. If phy_healthy_during_low_tp = False AND low_tp_avg_mcs < 12 AND neighbors_within_3dB >= 1:
-> "overlap coverage" (A)
Meaning: MCS crashes with strong neighbor present = interference/pilot pollution.

CONFIGURATION CHECK (if no rule above matches):
8. If inter_freq_ho = True AND a2_thld > -100 AND n_configured_neighbors >= 6:
-> "inter-freq HO threshold unreasonable" (B)
Meaning: Inter-frequency handover triggered with unreasonable A2 threshold.

Show your reasoning inside <think> tags. Then match the root cause to the option that describes it and answer with EXACTLY \\boxed{LABEL}. You MUST always end your response with \\boxed{LABEL}.
Examples: \\boxed{A}, \\boxed{D}, \\boxed{G}"""


GENERIC_SYSTEM_PROMPT = """You are an expert problem solver. Analyze questions carefully and select the correct answer.

IMPORTANT - Answer Format:
- Use the EXACT option number/label from the question
- Examples: \\boxed{2}, \\boxed{B}, \\boxed{72}

You must strictly output your reasoning process within <think>...</think> tags before the final answer."""


# =============================================================================
# ANSWER NORMALIZATION
# =============================================================================

def normalize_answer(answer: str) -> str:
    """Normalize answer format for comparison. C1 -> 1, c1 -> 1, 1 -> 1."""
    if not answer:
        return ""
    answer = answer.strip()
    match = re.match(r'^[Cc](\d+)$', answer)
    if match:
        return match.group(1)
    return answer


# =============================================================================
# REWARD FUNCTIONS (exactly 3)
# =============================================================================

BOXED_PATTERN = re.compile(r'\\boxed\s*\{\s*([^}]+?)\s*\}')


def boxed_reward(prompts, completions, **kwargs):
    """
    Reward for \\boxed{} presence.

    Returns:
        +0.5 if \\boxed{...} with content is present
        -0.5 if missing
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        if BOXED_PATTERN.search(response):
            scores.append(0.5)
        else:
            scores.append(-0.5)
    return scores


def think_tags_reward(prompts, completions, **kwargs):
    """
    Reward for <think>...</think> tags with non-trivial content.

    Returns:
        +1.0 if both tags present AND content >= 200 chars
        -0.5 if tags present but content too short (degenerate)
        -1.0 if either tag is missing
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        if '<think>' in response and '</think>' in response:
            # Extract content between think tags
            start = response.index('<think>') + len('<think>')
            end = response.index('</think>')
            think_content = response[start:end].strip()
            if len(think_content) >= 200:
                scores.append(1.0)
            else:
                scores.append(-0.5)
        else:
            scores.append(-1.0)
    return scores


def accuracy_reward(prompts, completions, answer, **kwargs):
    """
    Reward for correct answer matching ground truth.

    Returns:
        +5.0 for exact match
        +3.0 for match after normalization (C1 == 1)
        -2.0 for wrong answer
        -3.0 for no answer extracted
    """
    scores = []
    for completion, true_answer in zip(completions, answer):
        response = completion[0]["content"]

        match = BOXED_PATTERN.search(response)
        if not match:
            scores.append(-3.0)
            continue

        pred = match.group(1).strip()
        true = true_answer.strip()

        if pred == true:
            scores.append(5.0)
            continue

        pred_norm = normalize_answer(pred)
        true_norm = normalize_answer(true)
        if pred_norm == true_norm:
            scores.append(3.0)
            continue

        scores.append(-2.0)
    return scores

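
# Illustrative sketch (not called anywhere in this script; the demo completion
# below is invented): shows how the three reward functions above score one
# well-formed response. The nested-list structure mirrors the conversational
# completions that GRPOTrainer passes to reward functions.
def _example_reward_scores():
    demo_completions = [[{
        "role": "assistant",
        "content": "<think>" + "Checking Tier 1: max_speed is 12 km/h, below 40. " * 5
                    + "</think>\n\\boxed{C3}",
    }]]
    return {
        "boxed": boxed_reward(None, demo_completions),                # [0.5]: \boxed{} present
        "think": think_tags_reward(None, demo_completions),           # [1.0]: think block >= 200 chars
        "accuracy": accuracy_reward(None, demo_completions, ["C3"]),  # [5.0]: exact match
    }
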

# =============================================================================
# QUESTION TYPE DETECTION
# =============================================================================

def get_question_type(question: str, source_type: str = None) -> str:
    """Detect question type: 'type_a', 'type_b', or 'generic'.

    Uses source_type from trace data when available, falls back to heuristics.
    """
    if source_type and source_type in ('type_a', 'type_b', 'generic'):
        return source_type

    # Heuristic: Type B questions have drive test throughput drop analysis
    if 'throughput drop' in question.lower() and 'drive test' in question.lower():
        return 'type_b'

    # Heuristic: Type A questions have telco data tables
    if question.strip().startswith("Analyze the following question"):
        if '|' in question and question.count('|') >= 4:
            return 'type_a'

    # Default to generic
    return 'generic'

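
# Illustrative behaviour of the heuristics above (the question snippets are
# invented, not taken from the dataset):
#   get_question_type("In this drive test the throughput drop happens near ...")   -> 'type_b'
#   get_question_type("Analyze the following question ... | RSRP | SINR | RSRQ |") -> 'type_a'
#   get_question_type("Which option is correct?")                                  -> 'generic'
#   get_question_type("anything", source_type='type_b')                            -> 'type_b'
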

# =============================================================================
# DATA LOADING
# =============================================================================

def load_v19_traces(checkpoint_path: str) -> List[Dict]:
    """Load V19 reasoning traces.

    V19 traces have: question, expected_answer, reasoning_trace, question_type.
    All traces are pre-validated (success=True equivalent).
    """
    logger.info(f"Loading V19 traces from {checkpoint_path}")

    with open(checkpoint_path, 'r') as f:
        checkpoint = json.load(f)

    traces = []
    for row_id, data in checkpoint.items():
        traces.append({
            'row_id': row_id,
            'question': data['question'],
            'answer': data['expected_answer'],
            'question_type': data.get('question_type', 'type_a'),
            'source': 'v19_train',
        })

    logger.info(f"Loaded {len(traces)} V19 traces")

    type_counts = Counter(t['question_type'] for t in traces)
    for qt, count in sorted(type_counts.items()):
        logger.info(f"  {qt}: {count}")

    return traces

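
# Expected shape of the traces checkpoint consumed above (a sketch assembled
# from the keys this function reads and its docstring; field values are invented):
#
# {
#   "row_123": {
#     "question": "Analyze the following question ...",
#     "expected_answer": "C3",
#     "reasoning_trace": "<think>...</think>\n\\boxed{C3}",
#     "question_type": "type_a"
#   },
#   ...
# }
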

def load_test_augmentation(
    checkpoint_path: str,
    test_csv_path: str,
    min_agreement: int = 3,
) -> List[Dict]:
    """Load high-confidence test predictions for GRPO training."""
    import pandas as pd

    logger.info(f"Loading test augmentation from {checkpoint_path}")

    with open(checkpoint_path, 'r') as f:
        checkpoint = json.load(f)

    test_df = pd.read_csv(test_csv_path)
    id_to_question = dict(zip(test_df['ID'], test_df['question']))

    samples = []
    for row_key, data in checkpoint.items():
        question_id = data['ID']
        responses = data['responses']

        answers = [r['answer'] for r in responses if r.get('answer')]
        if not answers:
            continue

        answer_counts = Counter(answers)
        most_common_answer, count = answer_counts.most_common(1)[0]

        if count < min_agreement:
            continue

        full_question = id_to_question.get(question_id, data.get('question', ''))
        if not full_question:
            continue

        samples.append({
            'row_id': f"aug_{question_id}",
            'question': full_question,
            'answer': most_common_answer,
            'question_type': get_question_type(full_question),
            'source': 'augmentation',
        })

    logger.info(f"Loaded {len(samples)} augmentation samples (>={min_agreement}/4 agreement)")
    return samples


# =============================================================================
# DATASET PREPARATION
# =============================================================================

def compute_sample_weights(samples: List[Dict]) -> List[float]:
    """Compute per-sample weights for type-balanced sampling.

    Weight = total / (n_types * count_for_type), so each type contributes
    equally to expected samples drawn per epoch.
    """
    type_counts = Counter(s.get('question_type', 'type_a') for s in samples)
    n_types = len(type_counts)
    total = len(samples)

    weights = []
    for s in samples:
        qt = s.get('question_type', 'type_a')
        weights.append(total / (n_types * type_counts[qt]))

    logger.info("Type-balanced sampling weights:")
    for qt in sorted(type_counts):
        w = total / (n_types * type_counts[qt])
        logger.info(f"  {qt}: {w:.2f}x ({type_counts[qt]} samples)")

    return weights

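
# Worked example of the weight formula above (hypothetical type counts, chosen
# only to illustrate the arithmetic):
#   2400 type_a + 600 type_b samples -> total = 3000, n_types = 2
#   weight(type_a) = 3000 / (2 * 2400) = 0.625
#   weight(type_b) = 3000 / (2 * 600)  = 2.50
# Each type then carries the same expected mass: 0.625 * 2400 = 2.5 * 600 = 1500.
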

def strip_raw_tables(question: str) -> str:
    """Strip raw data tables from a telco question, keeping instructions + options.

    Works for both Type A and Type B questions.
    Must match the SFT training format exactly.
    """
    lines = question.split('\n')
    preamble_lines = []
    for line in lines:
        if line.count('|') >= 3:
            while preamble_lines and preamble_lines[-1].strip() == '':
                preamble_lines.pop()
            if preamble_lines and 'data as follows' in preamble_lines[-1].lower():
                preamble_lines.pop()
            break
        preamble_lines.append(line)

    result = '\n'.join(preamble_lines).strip()
    return result if result else question

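
# Illustrative input/output for strip_raw_tables (invented question text):
#
#   question = ("Analyze the following question. Options: C1 ... C4\n"
#               "The drive test data as follows:\n"
#               "| time | RSRP | SINR |\n"
#               "| 0s   | -85  | 12   |")
#   strip_raw_tables(question)   # -> "Analyze the following question. Options: C1 ... C4"
#
# The first line containing 3+ '|' characters marks the table; everything from
# there on is dropped, along with a trailing "data as follows" lead-in line.
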

def compute_type_a_metrics_for_question(question: str):
    """Compute Type A metrics block for a question. Returns formatted string or None."""
    try:
        drive_test, cells = parse_type_a_question(question)
        if drive_test:
            metrics = compute_all_metrics(question, drive_test, cells)
            return format_metrics_block(metrics)
    except Exception as e:
        logger.debug(f"Failed to compute Type A metrics: {e}")
    return None


def compute_type_b_metrics_for_question(question: str):
    """Compute Type B metrics block for a question. Returns formatted string or None."""
    try:
        m = compute_type_b_metrics(question)
        if m is not None:
            return format_type_b_metrics_block(m)
    except Exception as e:
        logger.debug(f"Failed to compute Type B metrics: {e}")
    return None


def prepare_grpo_dataset(
    samples: List[Dict],
    tokenizer,
) -> Dataset:
    """
    Prepare dataset for GRPO training.

    Pre-computes metrics and strips raw tables to match SFT training format.
    GRPO needs:
    - prompt: list of messages (system + user)
    - answer: ground truth for reward computation
    """
    formatted = []
    metrics_computed = 0
    metrics_failed = 0

    for sample in tqdm(samples, desc="Formatting prompts"):
        question = sample['question']
        answer = sample['answer']
        question_type = sample.get('question_type', 'type_a')

        # Select system prompt and compute metrics - must match SFT format
        if question_type == 'type_b':
            system_prompt = TYPE_B_SYSTEM_PROMPT
            metrics_block = compute_type_b_metrics_for_question(question)
        elif question_type == 'generic':
            system_prompt = GENERIC_SYSTEM_PROMPT
            metrics_block = None
        else:
            system_prompt = TELCO_SYSTEM_PROMPT
            metrics_block = compute_type_a_metrics_for_question(question)

        # Build user message matching SFT training format
        if metrics_block:
            question_preamble = strip_raw_tables(question)
            user_content = f"## Pre-computed Metrics\n\n{metrics_block}\n\n## Question\n\n{question_preamble}"
            metrics_computed += 1
        else:
            user_content = question
            if question_type != 'generic':
                metrics_failed += 1

        prompt = [
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_content},
        ]

        formatted.append({
            'prompt': prompt,
            'answer': answer,
            'row_id': sample.get('row_id', 'unknown'),
            'source': sample.get('source', 'unknown'),
        })

    logger.info(f"Metrics computed: {metrics_computed}, failed: {metrics_failed}")

    dataset = Dataset.from_list(formatted)

    # Analyze prompt lengths
    logger.info("Analyzing prompt lengths...")
    lengths = []
    for i in range(min(50, len(dataset))):
        text = tokenizer.apply_chat_template(
            dataset[i]['prompt'],
            tokenize=True,
            add_generation_prompt=True,
        )
        lengths.append(len(text))

    logger.info(f"Prompt length stats (first {len(lengths)} samples):")
    logger.info(f"  Min: {min(lengths)}")
    logger.info(f"  Max: {max(lengths)}")
    logger.info(f"  Mean: {sum(lengths)/len(lengths):.0f}")

    return dataset

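
# Shape of one formatted dataset row produced above (field values are invented;
# GRPOTrainer reads 'prompt' for generation and forwards extra columns such as
# 'answer' to the reward functions as keyword arguments):
#
# {
#   "prompt": [
#     {"role": "system", "content": TELCO_SYSTEM_PROMPT},
#     {"role": "user", "content": "## Pre-computed Metrics\n\n...\n\n## Question\n\n..."},
#   ],
#   "answer": "C3",
#   "row_id": "row_123",
#   "source": "v19_train",
# }
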

# =============================================================================
# LOAD SFT ADAPTER CONFIG
# =============================================================================

def load_sft_adapter_config(sft_model_path: str):
    """
    Load adapter configuration from SFT model.
    Returns (rank, target_modules, lora_alpha) or defaults if not found.
    """
    from peft import PeftConfig
    from huggingface_hub import hf_hub_download

    try:
        config = PeftConfig.from_pretrained(sft_model_path)
        logger.info(f"Loaded SFT adapter config from: {sft_model_path}")
        logger.info(f"  rank (r): {config.r}")
        logger.info(f"  target_modules: {list(config.target_modules)}")
        logger.info(f"  lora_alpha: {config.lora_alpha}")
        return config.r, list(config.target_modules), config.lora_alpha
    except Exception as e:
        logger.warning(f"Could not load PeftConfig: {e}")

    try:
        config_path = os.path.join(sft_model_path, "adapter_config.json")
        if not os.path.exists(config_path):
            config_path = hf_hub_download(
                repo_id=sft_model_path,
                filename="adapter_config.json",
            )

        with open(config_path, 'r') as f:
            config = json.load(f)

        rank = config.get('r', 32)
        target_modules = config.get('target_modules', [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ])
        lora_alpha = config.get('lora_alpha', rank * 2)

        logger.info("Loaded SFT adapter config from adapter_config.json")
        logger.info(f"  rank (r): {rank}")
        logger.info(f"  target_modules: {target_modules}")
        logger.info(f"  lora_alpha: {lora_alpha}")
        return rank, target_modules, lora_alpha
    except Exception as e:
        logger.warning(f"Could not load adapter_config.json: {e}")

    logger.warning("Using default LoRA config (r=32)")
    return 32, [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], 64


# =============================================================================
# TRAINING
# =============================================================================

def train(
    sft_model_path: str,
    base_model: str,
    train_checkpoint_path: str,
    test_checkpoint_path: str,
    test_csv_path: str,
    output_dir: str,
    hf_repo: str = None,
    hf_token: str = None,
    max_seq_length: int = 8192,
    lora_rank: int = None,
    max_steps: int = 100,
    num_generations: int = 6,
    learning_rate: float = 5e-6,
    temperature: float = 1.0,
    gradient_accumulation_steps: int = 4,
    gpu_memory_utilization: float = 0.95,
    min_agreement: int = 3,
    use_augmentation: bool = True,
    push_to_hub: bool = False,
    merge_16bit: bool = False,
    fast_inference: bool = True,
    dry_run: bool = False,
    seed: int = 42,
):
    """Main GRPO training function."""

    logger.info("=" * 60)
    logger.info("QWEN3-32B V19 GRPO TRAINING")
    logger.info("Rewards: boxed_reward, think_tags_reward, accuracy_reward")
    logger.info("=" * 60)

    # =================================
    # Load data
    # =================================
    train_traces = load_v19_traces(train_checkpoint_path)

    augmentation_samples = []
    if use_augmentation and os.path.exists(test_checkpoint_path):
        augmentation_samples = load_test_augmentation(
            test_checkpoint_path,
            test_csv_path,
            min_agreement=min_agreement,
        )

    all_samples = train_traces + augmentation_samples

    logger.info("\nDataset summary:")
    logger.info(f"  V19 traces: {len(train_traces)}")
    logger.info(f"  Augmentation: {len(augmentation_samples)}")
    logger.info(f"  Total: {len(all_samples)}")

    if len(all_samples) == 0:
        logger.error("No samples found!")
        return

    # Analyze answer distribution
    answer_counts = Counter(s['answer'] for s in all_samples)
    logger.info(f"\nAnswer distribution ({len(answer_counts)} unique):")
    for ans, cnt in sorted(answer_counts.items(), key=lambda x: -x[1])[:10]:
        logger.info(f"  {ans}: {cnt}")

    if dry_run:
        logger.info("\nDRY RUN - Data validation complete!")
        logger.info("Sample prompts:")
        for i, sample in enumerate(all_samples[:3]):
            logger.info(f"\n--- Sample {i+1} ({sample.get('question_type', '?')}) ---")
            logger.info(f"Question: {sample['question'][:200]}...")
            logger.info(f"Answer: {sample['answer']}")
        return

    # =================================
    # Load SFT adapter config
    # =================================
    sft_rank, sft_target_modules, sft_lora_alpha = load_sft_adapter_config(sft_model_path)

    if lora_rank is None:
        lora_rank = sft_rank
    elif lora_rank != sft_rank:
        logger.warning(f"CLI --lora-rank={lora_rank} differs from SFT rank={sft_rank}. Using CLI value.")

    # =================================
    # Load model
    # =================================
    from unsloth import is_bfloat16_supported

    logger.info(f"\nLoading base model: {base_model}")
    logger.info(f"SFT LoRA adapter: {sft_model_path}")
    logger.info(f"Fast inference (vLLM): {fast_inference}")

    from_pretrained_kwargs = {
        "model_name": base_model,
        "max_seq_length": max_seq_length,
        "load_in_4bit": True,
        "fast_inference": fast_inference,
    }
    if fast_inference:
        from_pretrained_kwargs["max_lora_rank"] = lora_rank
        from_pretrained_kwargs["gpu_memory_utilization"] = gpu_memory_utilization

    model, tokenizer = FastLanguageModel.from_pretrained(**from_pretrained_kwargs)

    logger.info(f"Setting up LoRA: rank={lora_rank}, target_modules={sft_target_modules}")

    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=sft_target_modules,
        lora_alpha=sft_lora_alpha,
        use_gradient_checkpointing="unsloth",
        random_state=seed,
    )

    # Load SFT LoRA weights
    from peft import set_peft_model_state_dict
    from safetensors.torch import load_file
    from huggingface_hub import hf_hub_download

    local_weights_path = os.path.join(sft_model_path, "adapter_model.safetensors")
    if os.path.exists(local_weights_path):
        sft_weights_path = local_weights_path
    else:
        try:
            sft_weights_path = hf_hub_download(
                repo_id=sft_model_path,
                filename="adapter_model.safetensors",
            )
        except Exception as e:
            logger.error(f"Could not download SFT weights: {e}")
            raise RuntimeError("Failed to load SFT weights. Check --sft-model path.")

    sft_state_dict = load_file(sft_weights_path)
    logger.info(f"Loading {len(sft_state_dict)} weight tensors from SFT")

    sft_keys = list(sft_state_dict.keys())
    model_keys = [k for k in model.state_dict().keys() if 'lora' in k.lower()]
    logger.info(f"SFT adapter key example: {sft_keys[0] if sft_keys else 'none'}")
    logger.info(f"Model LoRA key example: {model_keys[0] if model_keys else 'none'}")

    try:
        set_peft_model_state_dict(model, sft_state_dict)
        logger.info(f"Loaded SFT weights via set_peft_model_state_dict from: {sft_model_path}")
    except Exception as e:
        logger.warning(f"set_peft_model_state_dict failed: {e}, trying manual key mapping...")
        fixed_state_dict = {}
        for key, value in sft_state_dict.items():
            new_key = key
            for prefix in ['base_model.model.', 'base_model.']:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    break
            fixed_state_dict[new_key] = value
        missing, unexpected = model.load_state_dict(fixed_state_dict, strict=False)
        loaded = len(sft_state_dict) - len(unexpected)
        logger.info(f"Manual loading: {loaded}/{len(sft_state_dict)} tensors loaded")
        if unexpected:
            logger.warning(f"Could not load {len(unexpected)} tensors (key mismatch)")

    model.print_trainable_parameters()

    # =================================
    # Prepare dataset
    # =================================
    dataset = prepare_grpo_dataset(all_samples, tokenizer)
    logger.info(f"\nGRPO dataset: {len(dataset)} samples")

    logger.info("\nSample prompt:")
    sample_text = tokenizer.apply_chat_template(
        dataset[0]['prompt'],
        tokenize=False,
        add_generation_prompt=True,
    )
    logger.info(f"Length: {len(sample_text)} chars")
    logger.info(f"Preview:\n{sample_text[:500]}...")

    # =================================
    # GRPO Configuration
    # =================================
    from vllm import SamplingParams
    from trl import GRPOConfig, GRPOTrainer

    # With pre-computed metrics, prompts are ~1650 tokens max (Type A).
    # 2048 gives comfortable headroom without wasting seq budget on padding.
    min_prompt_budget = 2048
    max_prompt_length = min(min_prompt_budget, max_seq_length - 1024)
    max_completion_length = max_seq_length - max_prompt_length

    logger.info(f"Token budget: prompt={max_prompt_length}, completion={max_completion_length}")
    if max_prompt_length < 3200:
        logger.warning(
            f"max_prompt_length={max_prompt_length} may truncate long prompts. "
            f"Recommend --max-seq-length 4500 or higher."
        )
    if max_completion_length < 1500:
        logger.warning(
            f"max_completion_length={max_completion_length} limits reasoning space. "
            f"Recommend --max-seq-length 5500 or higher."
        )

    vllm_sampling_params = SamplingParams(
        min_p=0.1,
        top_p=0.95,
        top_k=50,
        repetition_penalty=1.05,
        seed=seed,
        stop=[tokenizer.eos_token],
        include_stop_str_in_output=True,
    )

    training_args = GRPOConfig(
        output_dir=f"{output_dir}/checkpoints",
        vllm_sampling_params=vllm_sampling_params,
        temperature=temperature,
        learning_rate=learning_rate,
        weight_decay=0.001,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        optim="adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_generations=num_generations,
        max_prompt_length=max_prompt_length,
        max_completion_length=max_completion_length,
        max_steps=max_steps,
        max_grad_norm=1.0,
        save_steps=max(50, max_steps // 2),
        report_to="none",
        seed=seed,
    )

    logger.info("\n" + "=" * 60)
    logger.info("GRPO CONFIGURATION")
    logger.info("=" * 60)
    logger.info(f"SFT Model: {sft_model_path}")
    logger.info(f"Max steps: {max_steps}")
    logger.info(f"Num generations: {num_generations}")
    logger.info(f"Learning rate: {learning_rate}")
    logger.info(f"Temperature: {temperature}")
    logger.info(f"Max prompt length: {max_prompt_length}")
    logger.info(f"Max completion length: {max_completion_length}")
    logger.info("Reward functions: boxed_reward, think_tags_reward, accuracy_reward")

    # =================================
    # Create trainer (exactly 3 rewards)
    # =================================
    # Compute per-sample weights for type-balanced sampling
    sample_weights = compute_sample_weights(all_samples)

    # Note: TypeBalancedGRPOTrainer with WeightedRandomSampler was removed because
    # GRPO's dataloader has special requirements for num_generations grouping.
    # Overriding get_train_dataloader breaks the reward reshaping.
    # Type balancing for GRPO is handled via the dataset composition instead.

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            boxed_reward,
            think_tags_reward,
            accuracy_reward,
        ],
        args=training_args,
        train_dataset=dataset,
    )

    logger.info("\n" + "=" * 60)
    logger.info("STARTING GRPO TRAINING")
    logger.info("=" * 60)

    trainer.train()

    # =================================
    # Save model
    # =================================
    logger.info("\n" + "=" * 60)
    logger.info("SAVING MODEL")
    logger.info("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    lora_output_dir = f"{output_dir}/lora"
    model.save_pretrained(lora_output_dir)
    tokenizer.save_pretrained(lora_output_dir)
    logger.info(f"LoRA adapter saved to: {lora_output_dir}")

    config = {
        'base_model': base_model,
        'sft_model': sft_model_path,
        'lora_rank': lora_rank,
        'target_modules': sft_target_modules,
        'lora_alpha': sft_lora_alpha,
        'max_seq_length': max_seq_length,
        'max_steps': max_steps,
        'num_generations': num_generations,
        'learning_rate': learning_rate,
        'temperature': temperature,
        'min_agreement': min_agreement,
        'train_samples': len(train_traces),
        'augmentation_samples': len(augmentation_samples),
        'total_samples': len(dataset),
        'reward_functions': ['boxed_reward', 'think_tags_reward', 'accuracy_reward'],
    }

    with open(f"{output_dir}/grpo_config.json", 'w') as f:
        json.dump(config, f, indent=2)

    # =================================
    # Merge to 16-bit (optional)
    # =================================
    merged_output_dir = None
    if merge_16bit:
        logger.info("\nMerging LoRA to 16-bit model...")
        merged_output_dir = f"{output_dir}/merged_16bit"

        from unsloth import FastLanguageModel as FLM

        merge_model, merge_tokenizer = FLM.from_pretrained(
            model_name=base_model,
            max_seq_length=max_seq_length,
            load_in_4bit=True,
            fast_inference=False,
        )

        merge_model = FLM.get_peft_model(
            merge_model,
            r=lora_rank,
            target_modules=sft_target_modules,
            lora_alpha=sft_lora_alpha,
        )

        from safetensors.torch import load_file as load_safetensors
        lora_weights_path = f"{lora_output_dir}/adapter_model.safetensors"
        if os.path.exists(lora_weights_path):
            state_dict = load_safetensors(lora_weights_path)
            merge_model.load_state_dict(state_dict, strict=False)
            logger.info(f"Loaded LoRA weights from {lora_weights_path}")

        merge_model.save_pretrained_merged(
            merged_output_dir,
            merge_tokenizer,
            save_method="merged_16bit",
        )
        logger.info(f"Merged model saved to: {merged_output_dir}")

    # =================================
    # Push to HuggingFace (optional)
    # =================================
    if push_to_hub and hf_repo:
        logger.info(f"\nPushing to HuggingFace: {hf_repo}")

        if hf_token:
            from huggingface_hub import login
            login(token=hf_token)

        from huggingface_hub import HfApi
        api = HfApi()

        if merge_16bit and merged_output_dir:
            logger.info("Pushing merged 16-bit model...")
            api.create_repo(repo_id=hf_repo, exist_ok=True)
            api.upload_folder(
                folder_path=merged_output_dir,
                repo_id=hf_repo,
                repo_type="model",
                commit_message="Upload V19 GRPO-trained merged model",
            )
            logger.info(f"Merged model pushed to: https://huggingface.co/{hf_repo}")
        else:
            lora_repo = f"{hf_repo}-lora" if not hf_repo.endswith("-lora") else hf_repo
            logger.info(f"Pushing LoRA adapter to: {lora_repo}")
            api.create_repo(repo_id=lora_repo, exist_ok=True)
            api.upload_folder(
                folder_path=lora_output_dir,
                repo_id=lora_repo,
                repo_type="model",
                commit_message="Upload V19 GRPO-trained LoRA adapter",
            )
            logger.info(f"LoRA adapter pushed to: https://huggingface.co/{lora_repo}")

    logger.info("\n" + "=" * 60)
    logger.info("V19 GRPO TRAINING COMPLETE!")
    logger.info("=" * 60)
    logger.info(f"LoRA adapter: {lora_output_dir}")
    if merged_output_dir:
        logger.info(f"Merged model: {merged_output_dir}")
    if push_to_hub and hf_repo:
        logger.info(f"HuggingFace: https://huggingface.co/{hf_repo}")


# =============================================================================
# MAIN
# =============================================================================

def main():
    parser = argparse.ArgumentParser(
        description="Train Qwen3-32B with V19 GRPO (3 rewards: boxed, think tags, accuracy)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model paths
    parser.add_argument(
        '--base-model', type=str,
        default='unsloth/Qwen3-32B-bnb-4bit',
        help='Base model (from HuggingFace)',
    )
    parser.add_argument(
        '--sft-model', type=str, required=True,
        help='Path or HF repo for V19 SFT LoRA adapter',
    )
    parser.add_argument(
        '--output-dir', type=str,
        default='./outputs/qwen3-32b-v19-grpo',
        help='Output directory for GRPO model',
    )

    # HuggingFace Hub
    parser.add_argument(
        '--hf-repo', type=str, default=None,
        help='HuggingFace repo to push model',
    )
    parser.add_argument(
        '--hf-token', type=str, default=None,
        help='HuggingFace token for pushing model',
    )
    parser.add_argument(
        '--push-to-hub', action='store_true',
        help='Push model to HuggingFace Hub after training',
    )
    parser.add_argument(
        '--merge-16bit', action='store_true',
        help='Merge LoRA into 16-bit model before pushing',
    )

    # Data paths
    parser.add_argument(
        '--train-checkpoint', type=str,
        default='./outputs/traces_final/traces_final.json',
        help='Path to training traces JSON',
    )
    parser.add_argument(
        '--test-checkpoint', type=str,
        default='',
        help='(unused) Path to test predictions for augmentation',
    )
    parser.add_argument(
        '--test-csv', type=str,
        default='',
        help='(unused) Path to test CSV for full questions',
    )

    # Model config
    parser.add_argument(
        '--max-seq-length', type=int, default=8192,
        help='Maximum sequence length',
    )
    parser.add_argument(
        '--lora-rank', type=int, default=None,
        help='LoRA rank (default: read from SFT adapter config)',
    )

    # Training config
    parser.add_argument(
        '--max-steps', type=int, default=100,
        help='Maximum training steps',
    )
    parser.add_argument(
        '--num-generations', type=int, default=6,
        help='Number of completions per prompt',
    )
    parser.add_argument(
        '--learning-rate', type=float, default=5e-6,
        help='Learning rate',
    )
    parser.add_argument(
        '--temperature', type=float, default=1.0,
        help='Sampling temperature for generation',
    )
    parser.add_argument(
        '--gradient-accumulation-steps', type=int, default=4,
        help='Gradient accumulation steps',
    )
    parser.add_argument(
        '--gpu-memory-utilization', type=float, default=0.95,
        help='GPU memory utilization for vLLM',
    )

    # Data config
    parser.add_argument(
        '--min-agreement', type=int, default=3, choices=[3, 4],
        help='Minimum agreement for augmentation samples',
    )
    parser.add_argument(
        '--no-augment', action='store_true',
        help='Disable test set augmentation',
    )

    # Utility
    parser.add_argument(
        '--no-fast-inference', action='store_true',
        help='Disable vLLM fast inference',
    )
    parser.add_argument(
        '--dry-run', action='store_true',
        help='Validate data without training',
    )
    parser.add_argument(
        '--seed', type=int, default=42,
        help='Random seed',
    )

    args = parser.parse_args()

    train(
        sft_model_path=args.sft_model,
        base_model=args.base_model,
        train_checkpoint_path=args.train_checkpoint,
        test_checkpoint_path=args.test_checkpoint,
        test_csv_path=args.test_csv,
        output_dir=args.output_dir,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        max_seq_length=args.max_seq_length,
        lora_rank=args.lora_rank,
        max_steps=args.max_steps,
        num_generations=args.num_generations,
        learning_rate=args.learning_rate,
        temperature=args.temperature,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gpu_memory_utilization=args.gpu_memory_utilization,
        min_agreement=args.min_agreement,
        use_augmentation=not args.no_augment,
        push_to_hub=args.push_to_hub,
        merge_16bit=args.merge_16bit,
        fast_inference=not args.no_fast_inference,
        dry_run=args.dry_run,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()