Upload train_grpo_final.py with huggingface_hub
Browse files- train_grpo_final.py +1131 -0
train_grpo_final.py
ADDED
|
@@ -0,0 +1,1131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train Qwen3-32B with GRPO on reasoning traces.
|
| 3 |
+
|
| 4 |
+
Exactly 3 reward functions:
|
| 5 |
+
1. boxed_reward: response contains \\boxed{...}
|
| 6 |
+
2. think_tags_reward: response contains <think>...</think>
|
| 7 |
+
3. accuracy_reward: extracted answer matches ground truth
|
| 8 |
+
|
| 9 |
+
Uses Unsloth for efficient training with vLLM fast inference.
|
| 10 |
+
Loads reasoning traces (2400 Type A from train.csv) with question-type-aware prompts.
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
# Basic (SFT LoRA from HF)
|
| 14 |
+
python train_grpo_final.py --sft-model USERNAME/sft-lora
|
| 15 |
+
|
| 16 |
+
# With more steps
|
| 17 |
+
python train_grpo_final.py --sft-model USERNAME/sft-lora --max-steps 200
|
| 18 |
+
|
| 19 |
+
# Push merged model to HF after training
|
| 20 |
+
python train_grpo_final.py \\
|
| 21 |
+
--sft-model USERNAME/sft-lora \\
|
| 22 |
+
--push-to-hub --merge-16bit \\
|
| 23 |
+
--hf-repo USERNAME/grpo-final \\
|
| 24 |
+
--hf-token hf_xxx
|
| 25 |
+
|
| 26 |
+
# Dry run (validate data)
|
| 27 |
+
python train_grpo_final.py --sft-model ./path/to/sft --dry-run
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
# Enable Unsloth's memory-efficient standby mode for vLLM
|
| 31 |
+
# Must be set BEFORE importing unsloth
|
| 32 |
+
import os
|
| 33 |
+
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
|
| 34 |
+
import json
|
| 35 |
+
import re
|
| 36 |
+
import argparse
|
| 37 |
+
import logging
|
| 38 |
+
from typing import Dict, List
|
| 39 |
+
from collections import Counter
|
| 40 |
+
|
| 41 |
+
from tqdm import tqdm
|
| 42 |
+
from datasets import Dataset
|
| 43 |
+
|
| 44 |
+
# Metric computation (same as SFT training)
|
| 45 |
+
from telco_utils import parse_type_a_question
|
| 46 |
+
from generate_traces_final import (
|
| 47 |
+
compute_all_metrics, format_metrics_block,
|
| 48 |
+
compute_type_b_metrics, format_type_b_metrics_block,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# =============================================================================
|
| 56 |
+
# PatchFastRL - Required for Unsloth + GRPOTrainer integration
|
| 57 |
+
# =============================================================================
|
| 58 |
+
from unsloth import FastLanguageModel, PatchFastRL
|
| 59 |
+
PatchFastRL("GRPO", FastLanguageModel)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# =============================================================================
|
| 63 |
+
# V19 SYSTEM PROMPTS
|
| 64 |
+
# =============================================================================
|
| 65 |
+
|
| 66 |
+
TELCO_SYSTEM_PROMPT = """You are a 5G network root cause classifier. You receive pre-computed metrics and a multiple-choice question. Walk through the decision rules below IN ORDER, show your work for each check, identify the root cause, then match it to the correct option label.
|
| 67 |
+
|
| 68 |
+
OUTPUT FORMAT (mandatory):
|
| 69 |
+
1. Wrap ALL reasoning inside <think>...</think> tags.
|
| 70 |
+
2. After </think>, output EXACTLY \\boxed{LABEL} where LABEL is the option (e.g. C3, 7, M5).
|
| 71 |
+
3. Do NOT write anything after \\boxed{LABEL}. No explanation, no period, no newline text.
|
| 72 |
+
4. Every response MUST end with \\boxed{LABEL}. Omitting it is a failure.
|
| 73 |
+
|
| 74 |
+
DECISION RULES (apply first matching rule):
|
| 75 |
+
|
| 76 |
+
TIER 1 - check in order, return first match:
|
| 77 |
+
1. max_speed > 40 -> "speed exceeds 40km/h"
|
| 78 |
+
2. max_distance_low_tp > 1.0 -> "coverage distance exceeds 1km" (overshooting)
|
| 79 |
+
3. handover_count >= 3 -> "frequent handovers"
|
| 80 |
+
4. avg_rb < 170 -> "average scheduled RBs below 160"
|
| 81 |
+
|
| 82 |
+
TIER 2 - C1 detection (if ANY sub-rule matches -> "downtilt too large"):
|
| 83 |
+
5a. min_rsrp < -90 AND pci_collision = no AND c4_interference < 3
|
| 84 |
+
5b. strong_neighbor_count < 0.5 AND serving_tilt >= 15
|
| 85 |
+
5c. pci_collision = yes AND strong_neighbor_count < 0.5
|
| 86 |
+
OVERRIDES on C1:
|
| 87 |
+
- post_ho_good_streak >= 2 -> "neighboring cell higher throughput" instead
|
| 88 |
+
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision" instead
|
| 89 |
+
- avg_rsrp > -79 AND strong_neighbor_count > 1.0 -> "neighboring cell higher throughput" instead
|
| 90 |
+
|
| 91 |
+
TIER 3 - interference:
|
| 92 |
+
6. c4_interference >= 3 -> "overlapping coverage/interference"
|
| 93 |
+
SKIP if: (min_neighbor_diff / c4_interference) < -0.5 AND c4_interference < 12
|
| 94 |
+
|
| 95 |
+
TIER 4 - PCI collision (pci_collision = yes):
|
| 96 |
+
7. If pci_collision_ratio >= 1.0 -> "PCI mod 30 collision"
|
| 97 |
+
If pci_collision_ratio < 1.0:
|
| 98 |
+
- serving_tilt > 10 AND rsrp_trend > 0.4 -> "downtilt too large"
|
| 99 |
+
- else -> "neighboring cell higher throughput"
|
| 100 |
+
If avg_off_axis > 30:
|
| 101 |
+
- min_rsrp < -90 -> "downtilt too large" (with override checks from Tier 2)
|
| 102 |
+
- else -> "neighboring cell higher throughput" (with override checks from Tier 5)
|
| 103 |
+
|
| 104 |
+
TIER 5 - C1 vs C3 tiebreaker:
|
| 105 |
+
8. serving_tilt >= 28 AND avg_sinr >= 12 -> "neighboring cell higher throughput" (SINR gate)
|
| 106 |
+
serving_tilt >= 28 AND avg_sinr < 12 -> "downtilt too large"
|
| 107 |
+
serving_tilt < 12 -> "neighboring cell higher throughput"
|
| 108 |
+
avg_rsrp < -90 -> "downtilt too large"
|
| 109 |
+
avg_rsrp > -82 -> "neighboring cell higher throughput"
|
| 110 |
+
Low confidence (avg_rsrp -90 to -82) -> rescue rules:
|
| 111 |
+
R1: pci_collision_ratio >= 0.9 -> "PCI mod 30 collision"
|
| 112 |
+
R2: strong_neighbor_count < 0.8 -> "downtilt too large"
|
| 113 |
+
R3: c4_interference >= 3.0 -> "downtilt too large"
|
| 114 |
+
R4: default -> "neighboring cell higher throughput"
|
| 115 |
+
OVERRIDES if "neighboring cell" (high/medium confidence):
|
| 116 |
+
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision"
|
| 117 |
+
- rsrp_change > 5 AND rsrp_trend > 0.5 AND nb_within_5db < 1.0 -> "downtilt too large"
|
| 118 |
+
- rsrp_recovery > 15 -> "downtilt too large"
|
| 119 |
+
- serving_tilt > 6 AND nb_within_5db < 1.0 -> "downtilt too large"
|
| 120 |
+
OVERRIDES if "downtilt" (high/medium confidence):
|
| 121 |
+
- pci_collision_ratio > 0.70 -> "PCI mod 30 collision"
|
| 122 |
+
- avg_rsrp > -79 AND strong_neighbor_count > 1.0 -> "neighboring cell higher throughput"
|
| 123 |
+
|
| 124 |
+
Show your reasoning inside <think> tags, checking each tier in order. Then match the identified root cause to the option that describes it and answer with EXACTLY \\boxed{LABEL}. You MUST always end your response with \\boxed{LABEL}.
|
| 125 |
+
Examples: \\boxed{C3}, \\boxed{7}, \\boxed{M5}"""
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
TYPE_B_SYSTEM_PROMPT = """You are a 5G drive test root cause analyzer. You receive pre-computed metrics and a multiple-choice question about throughput drops. Walk through the decision rules below IN ORDER, show your work for each check, identify the root cause, then match it to the correct option label.
|
| 129 |
+
|
| 130 |
+
OUTPUT FORMAT (mandatory):
|
| 131 |
+
1. Wrap ALL reasoning inside <think>...</think> tags.
|
| 132 |
+
2. After </think>, output EXACTLY \\boxed{LABEL} where LABEL is the option letter (e.g. A, D, G).
|
| 133 |
+
3. Do NOT write anything after \\boxed{LABEL}. No explanation, no period, no newline text.
|
| 134 |
+
4. Every response MUST end with \\boxed{LABEL}. Omitting it is a failure.
|
| 135 |
+
|
| 136 |
+
IMPORTANT: Options are SHUFFLED per question - identify the root cause FIRST, then find which option letter matches it.
|
| 137 |
+
|
| 138 |
+
DECISION RULES (apply first matching rule):
|
| 139 |
+
|
| 140 |
+
1. avg_cce_fail > 0.25 -> "PDCCH congestion" (I)
|
| 141 |
+
2. actual_handovers >= 3 -> "intra-freq threshold too low / ping-pong" (H)
|
| 142 |
+
3. ratio_a3_ho >= 3 AND a3_events >= 2, OR rrc_reestablish > 0 AND a3_events >= 1:
|
| 143 |
+
-> Check n1_in_config:
|
| 144 |
+
If n1_in_config = False -> "missing neighbor cell configuration" (E)
|
| 145 |
+
If n1_in_config = True -> "intra-freq threshold too high" (G)
|
| 146 |
+
4. rsrp_var_norm > 0.08 AND avg_rsrp > -95 -> "overlap coverage" (A)
|
| 147 |
+
5. avg_rsrp < -95 -> "weak coverage" (F)
|
| 148 |
+
|
| 149 |
+
PHY HEALTH ANALYSIS (if no rule above matches):
|
| 150 |
+
6. If phy_healthy_during_low_tp = True AND neighbors_within_3dB = 0 AND avg_sinr > 10:
|
| 151 |
+
-> "transport/server-side anomaly" (D)
|
| 152 |
+
Meaning: Radio link healthy during TP drops, bottleneck above PHY layer.
|
| 153 |
+
7. If phy_healthy_during_low_tp = False AND low_tp_avg_mcs < 12 AND neighbors_within_3dB >= 1:
|
| 154 |
+
-> "overlap coverage" (A)
|
| 155 |
+
Meaning: MCS crashes with strong neighbor present = interference/pilot pollution.
|
| 156 |
+
|
| 157 |
+
CONFIGURATION CHECK (if no rule above matches):
|
| 158 |
+
8. If inter_freq_ho = True AND a2_thld > -100 AND n_configured_neighbors >= 6:
|
| 159 |
+
-> "inter-freq HO threshold unreasonable" (B)
|
| 160 |
+
Meaning: Inter-frequency handover triggered with unreasonable A2 threshold.
|
| 161 |
+
|
| 162 |
+
Show your reasoning inside <think> tags. Then match the root cause to the option that describes it and answer with EXACTLY \\boxed{LABEL}. You MUST always end your response with \\boxed{LABEL}.
|
| 163 |
+
Examples: \\boxed{A}, \\boxed{D}, \\boxed{G}"""
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
GENERIC_SYSTEM_PROMPT = """You are an expert problem solver. Analyze questions carefully and select the correct answer.
|
| 167 |
+
|
| 168 |
+
IMPORTANT - Answer Format:
|
| 169 |
+
- Use the EXACT option number/label from the question
|
| 170 |
+
- Examples: \\boxed{2}, \\boxed{B}, \\boxed{72}
|
| 171 |
+
|
| 172 |
+
You must strictly output your reasoning process within <think>...</think> tags before the final answer."""
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# =============================================================================
|
| 176 |
+
# ANSWER NORMALIZATION
|
| 177 |
+
# =============================================================================
|
| 178 |
+
|
| 179 |
+
def normalize_answer(answer: str) -> str:
|
| 180 |
+
"""Normalize answer format for comparison. C1 -> 1, c1 -> 1, 1 -> 1."""
|
| 181 |
+
if not answer:
|
| 182 |
+
return ""
|
| 183 |
+
answer = answer.strip()
|
| 184 |
+
match = re.match(r'^[Cc](\d+)$', answer)
|
| 185 |
+
if match:
|
| 186 |
+
return match.group(1)
|
| 187 |
+
return answer
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# =============================================================================
|
| 191 |
+
# REWARD FUNCTIONS (exactly 3)
|
| 192 |
+
# =============================================================================
|
| 193 |
+
|
| 194 |
+
BOXED_PATTERN = re.compile(r'\\boxed\s*\{\s*([^}]+?)\s*\}')
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def boxed_reward(prompts, completions, **kwargs):
|
| 198 |
+
"""
|
| 199 |
+
Reward for \\boxed{} presence.
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
+0.5 if \\boxed{...} with content is present
|
| 203 |
+
-0.5 if missing
|
| 204 |
+
"""
|
| 205 |
+
scores = []
|
| 206 |
+
for completion in completions:
|
| 207 |
+
response = completion[0]["content"]
|
| 208 |
+
if BOXED_PATTERN.search(response):
|
| 209 |
+
scores.append(0.5)
|
| 210 |
+
else:
|
| 211 |
+
scores.append(-0.5)
|
| 212 |
+
return scores
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def think_tags_reward(prompts, completions, **kwargs):
|
| 216 |
+
"""
|
| 217 |
+
Reward for <think>...</think> tags with non-trivial content.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
+1.0 if both tags present AND content >= 50 chars
|
| 221 |
+
-0.5 if tags present but content too short (degenerate)
|
| 222 |
+
-1.0 if either tag is missing
|
| 223 |
+
"""
|
| 224 |
+
scores = []
|
| 225 |
+
for completion in completions:
|
| 226 |
+
response = completion[0]["content"]
|
| 227 |
+
if '<think>' in response and '</think>' in response:
|
| 228 |
+
# Extract content between think tags
|
| 229 |
+
start = response.index('<think>') + len('<think>')
|
| 230 |
+
end = response.index('</think>')
|
| 231 |
+
think_content = response[start:end].strip()
|
| 232 |
+
if len(think_content) >= 200:
|
| 233 |
+
scores.append(1.0)
|
| 234 |
+
else:
|
| 235 |
+
scores.append(-0.5)
|
| 236 |
+
else:
|
| 237 |
+
scores.append(-1.0)
|
| 238 |
+
return scores
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def accuracy_reward(prompts, completions, answer, **kwargs):
|
| 242 |
+
"""
|
| 243 |
+
Reward for correct answer matching ground truth.
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
+5.0 for exact match
|
| 247 |
+
+3.0 for match after normalization (C1 == 1)
|
| 248 |
+
-2.0 for wrong answer
|
| 249 |
+
-3.0 for no answer extracted
|
| 250 |
+
"""
|
| 251 |
+
scores = []
|
| 252 |
+
for completion, true_answer in zip(completions, answer):
|
| 253 |
+
response = completion[0]["content"]
|
| 254 |
+
|
| 255 |
+
match = BOXED_PATTERN.search(response)
|
| 256 |
+
if not match:
|
| 257 |
+
scores.append(-3.0)
|
| 258 |
+
continue
|
| 259 |
+
|
| 260 |
+
pred = match.group(1).strip()
|
| 261 |
+
true = true_answer.strip()
|
| 262 |
+
|
| 263 |
+
if pred == true:
|
| 264 |
+
scores.append(5.0)
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
pred_norm = normalize_answer(pred)
|
| 268 |
+
true_norm = normalize_answer(true)
|
| 269 |
+
if pred_norm == true_norm:
|
| 270 |
+
scores.append(3.0)
|
| 271 |
+
continue
|
| 272 |
+
|
| 273 |
+
scores.append(-2.0)
|
| 274 |
+
return scores
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# =============================================================================
|
| 278 |
+
# QUESTION TYPE DETECTION
|
| 279 |
+
# =============================================================================
|
| 280 |
+
|
| 281 |
+
def get_question_type(question: str, source_type: str = None) -> str:
|
| 282 |
+
"""Detect question type: 'type_a', 'type_b', or 'generic'.
|
| 283 |
+
|
| 284 |
+
Uses source_type from trace data when available, falls back to heuristics.
|
| 285 |
+
"""
|
| 286 |
+
if source_type and source_type in ('type_a', 'type_b', 'generic'):
|
| 287 |
+
return source_type
|
| 288 |
+
|
| 289 |
+
# Heuristic: Type B questions have drive test throughput drop analysis
|
| 290 |
+
if 'throughput drop' in question.lower() and 'drive test' in question.lower():
|
| 291 |
+
return 'type_b'
|
| 292 |
+
|
| 293 |
+
# Heuristic: Type A questions have telco data tables
|
| 294 |
+
if question.strip().startswith("Analyze the following question"):
|
| 295 |
+
if '|' in question and question.count('|') >= 4:
|
| 296 |
+
return 'type_a'
|
| 297 |
+
|
| 298 |
+
# Default to generic
|
| 299 |
+
return 'generic'
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# =============================================================================
|
| 303 |
+
# DATA LOADING
|
| 304 |
+
# =============================================================================
|
| 305 |
+
|
| 306 |
+
def load_v19_traces(checkpoint_path: str) -> List[Dict]:
|
| 307 |
+
"""Load V19 reasoning traces.
|
| 308 |
+
|
| 309 |
+
V19 traces have: question, expected_answer, reasoning_trace, question_type.
|
| 310 |
+
All traces are pre-validated (success=True equivalent).
|
| 311 |
+
"""
|
| 312 |
+
logger.info(f"Loading V19 traces from {checkpoint_path}")
|
| 313 |
+
|
| 314 |
+
with open(checkpoint_path, 'r') as f:
|
| 315 |
+
checkpoint = json.load(f)
|
| 316 |
+
|
| 317 |
+
traces = []
|
| 318 |
+
for row_id, data in checkpoint.items():
|
| 319 |
+
traces.append({
|
| 320 |
+
'row_id': row_id,
|
| 321 |
+
'question': data['question'],
|
| 322 |
+
'answer': data['expected_answer'],
|
| 323 |
+
'question_type': data.get('question_type', 'type_a'),
|
| 324 |
+
'source': 'v19_train',
|
| 325 |
+
})
|
| 326 |
+
|
| 327 |
+
logger.info(f"Loaded {len(traces)} V19 traces")
|
| 328 |
+
|
| 329 |
+
type_counts = Counter(t['question_type'] for t in traces)
|
| 330 |
+
for qt, count in sorted(type_counts.items()):
|
| 331 |
+
logger.info(f" {qt}: {count}")
|
| 332 |
+
|
| 333 |
+
return traces
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def load_test_augmentation(
|
| 337 |
+
checkpoint_path: str,
|
| 338 |
+
test_csv_path: str,
|
| 339 |
+
min_agreement: int = 3,
|
| 340 |
+
) -> List[Dict]:
|
| 341 |
+
"""Load high-confidence test predictions for GRPO training."""
|
| 342 |
+
import pandas as pd
|
| 343 |
+
|
| 344 |
+
logger.info(f"Loading test augmentation from {checkpoint_path}")
|
| 345 |
+
|
| 346 |
+
with open(checkpoint_path, 'r') as f:
|
| 347 |
+
checkpoint = json.load(f)
|
| 348 |
+
|
| 349 |
+
test_df = pd.read_csv(test_csv_path)
|
| 350 |
+
id_to_question = dict(zip(test_df['ID'], test_df['question']))
|
| 351 |
+
|
| 352 |
+
samples = []
|
| 353 |
+
for row_key, data in checkpoint.items():
|
| 354 |
+
question_id = data['ID']
|
| 355 |
+
responses = data['responses']
|
| 356 |
+
|
| 357 |
+
answers = [r['answer'] for r in responses if r.get('answer')]
|
| 358 |
+
if not answers:
|
| 359 |
+
continue
|
| 360 |
+
|
| 361 |
+
answer_counts = Counter(answers)
|
| 362 |
+
most_common_answer, count = answer_counts.most_common(1)[0]
|
| 363 |
+
|
| 364 |
+
if count < min_agreement:
|
| 365 |
+
continue
|
| 366 |
+
|
| 367 |
+
full_question = id_to_question.get(question_id, data.get('question', ''))
|
| 368 |
+
if not full_question:
|
| 369 |
+
continue
|
| 370 |
+
|
| 371 |
+
samples.append({
|
| 372 |
+
'row_id': f"aug_{question_id}",
|
| 373 |
+
'question': full_question,
|
| 374 |
+
'answer': most_common_answer,
|
| 375 |
+
'question_type': get_question_type(full_question),
|
| 376 |
+
'source': 'augmentation',
|
| 377 |
+
})
|
| 378 |
+
|
| 379 |
+
logger.info(f"Loaded {len(samples)} augmentation samples (>={min_agreement}/4 agreement)")
|
| 380 |
+
return samples
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
# =============================================================================
|
| 384 |
+
# DATASET PREPARATION
|
| 385 |
+
# =============================================================================
|
| 386 |
+
|
| 387 |
+
def compute_sample_weights(samples: List[Dict]) -> List[float]:
|
| 388 |
+
"""Compute per-sample weights for type-balanced sampling.
|
| 389 |
+
|
| 390 |
+
Weight = total / (n_types * count_for_type), so each type contributes
|
| 391 |
+
equally to expected samples drawn per epoch.
|
| 392 |
+
"""
|
| 393 |
+
type_counts = Counter(s.get('question_type', 'type_a') for s in samples)
|
| 394 |
+
n_types = len(type_counts)
|
| 395 |
+
total = len(samples)
|
| 396 |
+
|
| 397 |
+
weights = []
|
| 398 |
+
for s in samples:
|
| 399 |
+
qt = s.get('question_type', 'type_a')
|
| 400 |
+
weights.append(total / (n_types * type_counts[qt]))
|
| 401 |
+
|
| 402 |
+
logger.info("Type-balanced sampling weights:")
|
| 403 |
+
for qt in sorted(type_counts):
|
| 404 |
+
w = total / (n_types * type_counts[qt])
|
| 405 |
+
logger.info(f" {qt}: {w:.2f}x ({type_counts[qt]} samples)")
|
| 406 |
+
|
| 407 |
+
return weights
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def strip_raw_tables(question: str) -> str:
|
| 411 |
+
"""Strip raw data tables from a telco question, keeping instructions + options.
|
| 412 |
+
|
| 413 |
+
Works for both Type A and Type B questions.
|
| 414 |
+
Must match the SFT training format exactly.
|
| 415 |
+
"""
|
| 416 |
+
lines = question.split('\n')
|
| 417 |
+
preamble_lines = []
|
| 418 |
+
for line in lines:
|
| 419 |
+
if line.count('|') >= 3:
|
| 420 |
+
while preamble_lines and preamble_lines[-1].strip() == '':
|
| 421 |
+
preamble_lines.pop()
|
| 422 |
+
if preamble_lines and 'data as follows' in preamble_lines[-1].lower():
|
| 423 |
+
preamble_lines.pop()
|
| 424 |
+
break
|
| 425 |
+
preamble_lines.append(line)
|
| 426 |
+
|
| 427 |
+
result = '\n'.join(preamble_lines).strip()
|
| 428 |
+
return result if result else question
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def compute_type_a_metrics_for_question(question: str):
|
| 432 |
+
"""Compute Type A metrics block for a question. Returns formatted string or None."""
|
| 433 |
+
try:
|
| 434 |
+
drive_test, cells = parse_type_a_question(question)
|
| 435 |
+
if drive_test:
|
| 436 |
+
metrics = compute_all_metrics(question, drive_test, cells)
|
| 437 |
+
return format_metrics_block(metrics)
|
| 438 |
+
except Exception as e:
|
| 439 |
+
logger.debug(f"Failed to compute Type A metrics: {e}")
|
| 440 |
+
return None
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def compute_type_b_metrics_for_question(question: str):
|
| 444 |
+
"""Compute Type B metrics block for a question. Returns formatted string or None."""
|
| 445 |
+
try:
|
| 446 |
+
m = compute_type_b_metrics(question)
|
| 447 |
+
if m is not None:
|
| 448 |
+
return format_type_b_metrics_block(m)
|
| 449 |
+
except Exception as e:
|
| 450 |
+
logger.debug(f"Failed to compute Type B metrics: {e}")
|
| 451 |
+
return None
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def prepare_grpo_dataset(
|
| 455 |
+
samples: List[Dict],
|
| 456 |
+
tokenizer,
|
| 457 |
+
) -> Dataset:
|
| 458 |
+
"""
|
| 459 |
+
Prepare dataset for GRPO training.
|
| 460 |
+
|
| 461 |
+
Pre-computes metrics and strips raw tables to match SFT training format.
|
| 462 |
+
GRPO needs:
|
| 463 |
+
- prompt: list of messages (system + user)
|
| 464 |
+
- answer: ground truth for reward computation
|
| 465 |
+
"""
|
| 466 |
+
formatted = []
|
| 467 |
+
metrics_computed = 0
|
| 468 |
+
metrics_failed = 0
|
| 469 |
+
|
| 470 |
+
for sample in tqdm(samples, desc="Formatting prompts"):
|
| 471 |
+
question = sample['question']
|
| 472 |
+
answer = sample['answer']
|
| 473 |
+
question_type = sample.get('question_type', 'type_a')
|
| 474 |
+
|
| 475 |
+
# Select system prompt and compute metrics - must match SFT format
|
| 476 |
+
if question_type == 'type_b':
|
| 477 |
+
system_prompt = TYPE_B_SYSTEM_PROMPT
|
| 478 |
+
metrics_block = compute_type_b_metrics_for_question(question)
|
| 479 |
+
elif question_type == 'generic':
|
| 480 |
+
system_prompt = GENERIC_SYSTEM_PROMPT
|
| 481 |
+
metrics_block = None
|
| 482 |
+
else:
|
| 483 |
+
system_prompt = TELCO_SYSTEM_PROMPT
|
| 484 |
+
metrics_block = compute_type_a_metrics_for_question(question)
|
| 485 |
+
|
| 486 |
+
# Build user message matching SFT training format
|
| 487 |
+
if metrics_block:
|
| 488 |
+
question_preamble = strip_raw_tables(question)
|
| 489 |
+
user_content = f"## Pre-computed Metrics\n\n{metrics_block}\n\n## Question\n\n{question_preamble}"
|
| 490 |
+
metrics_computed += 1
|
| 491 |
+
else:
|
| 492 |
+
user_content = question
|
| 493 |
+
if question_type != 'generic':
|
| 494 |
+
metrics_failed += 1
|
| 495 |
+
|
| 496 |
+
prompt = [
|
| 497 |
+
{'role': 'system', 'content': system_prompt},
|
| 498 |
+
{'role': 'user', 'content': user_content},
|
| 499 |
+
]
|
| 500 |
+
|
| 501 |
+
formatted.append({
|
| 502 |
+
'prompt': prompt,
|
| 503 |
+
'answer': answer,
|
| 504 |
+
'row_id': sample.get('row_id', 'unknown'),
|
| 505 |
+
'source': sample.get('source', 'unknown'),
|
| 506 |
+
})
|
| 507 |
+
|
| 508 |
+
logger.info(f"Metrics computed: {metrics_computed}, failed: {metrics_failed}")
|
| 509 |
+
|
| 510 |
+
dataset = Dataset.from_list(formatted)
|
| 511 |
+
|
| 512 |
+
# Analyze prompt lengths
|
| 513 |
+
logger.info("Analyzing prompt lengths...")
|
| 514 |
+
lengths = []
|
| 515 |
+
for i in range(min(50, len(dataset))):
|
| 516 |
+
text = tokenizer.apply_chat_template(
|
| 517 |
+
dataset[i]['prompt'],
|
| 518 |
+
tokenize=True,
|
| 519 |
+
add_generation_prompt=True,
|
| 520 |
+
)
|
| 521 |
+
lengths.append(len(text))
|
| 522 |
+
|
| 523 |
+
logger.info(f"Prompt length stats (first {len(lengths)} samples):")
|
| 524 |
+
logger.info(f" Min: {min(lengths)}")
|
| 525 |
+
logger.info(f" Max: {max(lengths)}")
|
| 526 |
+
logger.info(f" Mean: {sum(lengths)/len(lengths):.0f}")
|
| 527 |
+
|
| 528 |
+
return dataset
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# =============================================================================
|
| 532 |
+
# LOAD SFT ADAPTER CONFIG
|
| 533 |
+
# =============================================================================
|
| 534 |
+
|
| 535 |
+
def load_sft_adapter_config(sft_model_path: str):
|
| 536 |
+
"""
|
| 537 |
+
Load adapter configuration from SFT model.
|
| 538 |
+
Returns (rank, target_modules, lora_alpha) or defaults if not found.
|
| 539 |
+
"""
|
| 540 |
+
from peft import PeftConfig
|
| 541 |
+
from huggingface_hub import hf_hub_download
|
| 542 |
+
|
| 543 |
+
try:
|
| 544 |
+
config = PeftConfig.from_pretrained(sft_model_path)
|
| 545 |
+
logger.info(f"Loaded SFT adapter config from: {sft_model_path}")
|
| 546 |
+
logger.info(f" rank (r): {config.r}")
|
| 547 |
+
logger.info(f" target_modules: {list(config.target_modules)}")
|
| 548 |
+
logger.info(f" lora_alpha: {config.lora_alpha}")
|
| 549 |
+
return config.r, list(config.target_modules), config.lora_alpha
|
| 550 |
+
except Exception as e:
|
| 551 |
+
logger.warning(f"Could not load PeftConfig: {e}")
|
| 552 |
+
|
| 553 |
+
try:
|
| 554 |
+
config_path = os.path.join(sft_model_path, "adapter_config.json")
|
| 555 |
+
if not os.path.exists(config_path):
|
| 556 |
+
config_path = hf_hub_download(
|
| 557 |
+
repo_id=sft_model_path,
|
| 558 |
+
filename="adapter_config.json",
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
with open(config_path, 'r') as f:
|
| 562 |
+
config = json.load(f)
|
| 563 |
+
|
| 564 |
+
rank = config.get('r', 32)
|
| 565 |
+
target_modules = config.get('target_modules', [
|
| 566 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 567 |
+
"gate_proj", "up_proj", "down_proj",
|
| 568 |
+
])
|
| 569 |
+
lora_alpha = config.get('lora_alpha', rank * 2)
|
| 570 |
+
|
| 571 |
+
logger.info(f"Loaded SFT adapter config from adapter_config.json")
|
| 572 |
+
logger.info(f" rank (r): {rank}")
|
| 573 |
+
logger.info(f" target_modules: {target_modules}")
|
| 574 |
+
logger.info(f" lora_alpha: {lora_alpha}")
|
| 575 |
+
return rank, target_modules, lora_alpha
|
| 576 |
+
except Exception as e:
|
| 577 |
+
logger.warning(f"Could not load adapter_config.json: {e}")
|
| 578 |
+
|
| 579 |
+
logger.warning("Using default LoRA config (r=32)")
|
| 580 |
+
return 32, [
|
| 581 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 582 |
+
"gate_proj", "up_proj", "down_proj",
|
| 583 |
+
], 64
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# =============================================================================
|
| 587 |
+
# TRAINING
|
| 588 |
+
# =============================================================================
|
| 589 |
+
|
| 590 |
+
def train(
|
| 591 |
+
sft_model_path: str,
|
| 592 |
+
base_model: str,
|
| 593 |
+
train_checkpoint_path: str,
|
| 594 |
+
test_checkpoint_path: str,
|
| 595 |
+
test_csv_path: str,
|
| 596 |
+
output_dir: str,
|
| 597 |
+
hf_repo: str = None,
|
| 598 |
+
hf_token: str = None,
|
| 599 |
+
max_seq_length: int = 8192,
|
| 600 |
+
lora_rank: int = None,
|
| 601 |
+
max_steps: int = 100,
|
| 602 |
+
num_generations: int = 6,
|
| 603 |
+
learning_rate: float = 5e-6,
|
| 604 |
+
temperature: float = 1.0,
|
| 605 |
+
gradient_accumulation_steps: int = 4,
|
| 606 |
+
gpu_memory_utilization: float = 0.95,
|
| 607 |
+
min_agreement: int = 3,
|
| 608 |
+
use_augmentation: bool = True,
|
| 609 |
+
push_to_hub: bool = False,
|
| 610 |
+
merge_16bit: bool = False,
|
| 611 |
+
fast_inference: bool = True,
|
| 612 |
+
dry_run: bool = False,
|
| 613 |
+
seed: int = 42,
|
| 614 |
+
):
|
| 615 |
+
"""Main GRPO training function."""
|
| 616 |
+
|
| 617 |
+
logger.info("=" * 60)
|
| 618 |
+
logger.info("QWEN3-32B V19 GRPO TRAINING")
|
| 619 |
+
logger.info("Rewards: boxed_reward, think_tags_reward, accuracy_reward")
|
| 620 |
+
logger.info("=" * 60)
|
| 621 |
+
|
| 622 |
+
# =================================
|
| 623 |
+
# Load data
|
| 624 |
+
# =================================
|
| 625 |
+
train_traces = load_v19_traces(train_checkpoint_path)
|
| 626 |
+
|
| 627 |
+
augmentation_samples = []
|
| 628 |
+
if use_augmentation and os.path.exists(test_checkpoint_path):
|
| 629 |
+
augmentation_samples = load_test_augmentation(
|
| 630 |
+
test_checkpoint_path,
|
| 631 |
+
test_csv_path,
|
| 632 |
+
min_agreement=min_agreement,
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
all_samples = train_traces + augmentation_samples
|
| 636 |
+
|
| 637 |
+
logger.info(f"\nDataset summary:")
|
| 638 |
+
logger.info(f" V19 traces: {len(train_traces)}")
|
| 639 |
+
logger.info(f" Augmentation: {len(augmentation_samples)}")
|
| 640 |
+
logger.info(f" Total: {len(all_samples)}")
|
| 641 |
+
|
| 642 |
+
if len(all_samples) == 0:
|
| 643 |
+
logger.error("No samples found!")
|
| 644 |
+
return
|
| 645 |
+
|
| 646 |
+
# Analyze answer distribution
|
| 647 |
+
answer_counts = Counter(s['answer'] for s in all_samples)
|
| 648 |
+
logger.info(f"\nAnswer distribution ({len(answer_counts)} unique):")
|
| 649 |
+
for ans, cnt in sorted(answer_counts.items(), key=lambda x: -x[1])[:10]:
|
| 650 |
+
logger.info(f" {ans}: {cnt}")
|
| 651 |
+
|
| 652 |
+
if dry_run:
|
| 653 |
+
logger.info("\nDRY RUN - Data validation complete!")
|
| 654 |
+
logger.info("Sample prompts:")
|
| 655 |
+
for i, sample in enumerate(all_samples[:3]):
|
| 656 |
+
logger.info(f"\n--- Sample {i+1} ({sample.get('question_type', '?')}) ---")
|
| 657 |
+
logger.info(f"Question: {sample['question'][:200]}...")
|
| 658 |
+
logger.info(f"Answer: {sample['answer']}")
|
| 659 |
+
return
|
| 660 |
+
|
| 661 |
+
# =================================
|
| 662 |
+
# Load SFT adapter config
|
| 663 |
+
# =================================
|
| 664 |
+
sft_rank, sft_target_modules, sft_lora_alpha = load_sft_adapter_config(sft_model_path)
|
| 665 |
+
|
| 666 |
+
if lora_rank is None:
|
| 667 |
+
lora_rank = sft_rank
|
| 668 |
+
elif lora_rank != sft_rank:
|
| 669 |
+
logger.warning(f"CLI --lora-rank={lora_rank} differs from SFT rank={sft_rank}. Using CLI value.")
|
| 670 |
+
|
| 671 |
+
# =================================
|
| 672 |
+
# Load model
|
| 673 |
+
# =================================
|
| 674 |
+
from unsloth import is_bfloat16_supported
|
| 675 |
+
|
| 676 |
+
logger.info(f"\nLoading base model: {base_model}")
|
| 677 |
+
logger.info(f"SFT LoRA adapter: {sft_model_path}")
|
| 678 |
+
logger.info(f"Fast inference (vLLM): {fast_inference}")
|
| 679 |
+
|
| 680 |
+
from_pretrained_kwargs = {
|
| 681 |
+
"model_name": base_model,
|
| 682 |
+
"max_seq_length": max_seq_length,
|
| 683 |
+
"load_in_4bit": True,
|
| 684 |
+
"fast_inference": fast_inference,
|
| 685 |
+
}
|
| 686 |
+
if fast_inference:
|
| 687 |
+
from_pretrained_kwargs["max_lora_rank"] = lora_rank
|
| 688 |
+
from_pretrained_kwargs["gpu_memory_utilization"] = gpu_memory_utilization
|
| 689 |
+
|
| 690 |
+
model, tokenizer = FastLanguageModel.from_pretrained(**from_pretrained_kwargs)
|
| 691 |
+
|
| 692 |
+
logger.info(f"Setting up LoRA: rank={lora_rank}, target_modules={sft_target_modules}")
|
| 693 |
+
|
| 694 |
+
model = FastLanguageModel.get_peft_model(
|
| 695 |
+
model,
|
| 696 |
+
r=lora_rank,
|
| 697 |
+
target_modules=sft_target_modules,
|
| 698 |
+
lora_alpha=sft_lora_alpha,
|
| 699 |
+
use_gradient_checkpointing="unsloth",
|
| 700 |
+
random_state=seed,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
# Load SFT LoRA weights
|
| 704 |
+
from peft import set_peft_model_state_dict
|
| 705 |
+
from safetensors.torch import load_file
|
| 706 |
+
from huggingface_hub import hf_hub_download
|
| 707 |
+
|
| 708 |
+
local_weights_path = os.path.join(sft_model_path, "adapter_model.safetensors")
|
| 709 |
+
if os.path.exists(local_weights_path):
|
| 710 |
+
sft_weights_path = local_weights_path
|
| 711 |
+
else:
|
| 712 |
+
try:
|
| 713 |
+
sft_weights_path = hf_hub_download(
|
| 714 |
+
repo_id=sft_model_path,
|
| 715 |
+
filename="adapter_model.safetensors",
|
| 716 |
+
)
|
| 717 |
+
except Exception as e:
|
| 718 |
+
logger.error(f"Could not download SFT weights: {e}")
|
| 719 |
+
raise RuntimeError("Failed to load SFT weights. Check --sft-model path.")
|
| 720 |
+
|
| 721 |
+
sft_state_dict = load_file(sft_weights_path)
|
| 722 |
+
logger.info(f"Loading {len(sft_state_dict)} weight tensors from SFT")
|
| 723 |
+
|
| 724 |
+
sft_keys = list(sft_state_dict.keys())
|
| 725 |
+
model_keys = [k for k in model.state_dict().keys() if 'lora' in k.lower()]
|
| 726 |
+
logger.info(f"SFT adapter key example: {sft_keys[0] if sft_keys else 'none'}")
|
| 727 |
+
logger.info(f"Model LoRA key example: {model_keys[0] if model_keys else 'none'}")
|
| 728 |
+
|
| 729 |
+
try:
|
| 730 |
+
set_peft_model_state_dict(model, sft_state_dict)
|
| 731 |
+
logger.info(f"Loaded SFT weights via set_peft_model_state_dict from: {sft_model_path}")
|
| 732 |
+
except Exception as e:
|
| 733 |
+
logger.warning(f"set_peft_model_state_dict failed: {e}, trying manual key mapping...")
|
| 734 |
+
fixed_state_dict = {}
|
| 735 |
+
for key, value in sft_state_dict.items():
|
| 736 |
+
new_key = key
|
| 737 |
+
for prefix in ['base_model.model.', 'base_model.']:
|
| 738 |
+
if new_key.startswith(prefix):
|
| 739 |
+
new_key = new_key[len(prefix):]
|
| 740 |
+
break
|
| 741 |
+
fixed_state_dict[new_key] = value
|
| 742 |
+
missing, unexpected = model.load_state_dict(fixed_state_dict, strict=False)
|
| 743 |
+
loaded = len(sft_state_dict) - len(unexpected)
|
| 744 |
+
logger.info(f"Manual loading: {loaded}/{len(sft_state_dict)} tensors loaded")
|
| 745 |
+
if unexpected:
|
| 746 |
+
logger.warning(f"Could not load {len(unexpected)} tensors (key mismatch)")
|
| 747 |
+
|
| 748 |
+
model.print_trainable_parameters()
|
| 749 |
+
|
| 750 |
+
# =================================
|
| 751 |
+
# Prepare dataset
|
| 752 |
+
# =================================
|
| 753 |
+
dataset = prepare_grpo_dataset(all_samples, tokenizer)
|
| 754 |
+
logger.info(f"\nGRPO dataset: {len(dataset)} samples")
|
| 755 |
+
|
| 756 |
+
logger.info("\nSample prompt:")
|
| 757 |
+
sample_text = tokenizer.apply_chat_template(
|
| 758 |
+
dataset[0]['prompt'],
|
| 759 |
+
tokenize=False,
|
| 760 |
+
add_generation_prompt=True,
|
| 761 |
+
)
|
| 762 |
+
logger.info(f"Length: {len(sample_text)} chars")
|
| 763 |
+
logger.info(f"Preview:\n{sample_text[:500]}...")
|
| 764 |
+
|
| 765 |
+
# =================================
|
| 766 |
+
# GRPO Configuration
|
| 767 |
+
# =================================
|
| 768 |
+
from vllm import SamplingParams
|
| 769 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 770 |
+
|
| 771 |
+
# With pre-computed metrics, prompts are ~1650 tokens max (Type A).
|
| 772 |
+
# 2048 gives comfortable headroom without wasting seq budget on padding.
|
| 773 |
+
min_prompt_budget = 2048
|
| 774 |
+
max_prompt_length = min(min_prompt_budget, max_seq_length - 1024)
|
| 775 |
+
max_completion_length = max_seq_length - max_prompt_length
|
| 776 |
+
|
| 777 |
+
logger.info(f"Token budget: prompt={max_prompt_length}, completion={max_completion_length}")
|
| 778 |
+
if max_prompt_length < 3200:
|
| 779 |
+
logger.warning(
|
| 780 |
+
f"max_prompt_length={max_prompt_length} may truncate long prompts. "
|
| 781 |
+
f"Recommend --max-seq-length 4500 or higher."
|
| 782 |
+
)
|
| 783 |
+
if max_completion_length < 1500:
|
| 784 |
+
logger.warning(
|
| 785 |
+
f"max_completion_length={max_completion_length} limits reasoning space. "
|
| 786 |
+
f"Recommend --max-seq-length 5500 or higher."
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
vllm_sampling_params = SamplingParams(
|
| 790 |
+
min_p=0.1,
|
| 791 |
+
top_p=0.95,
|
| 792 |
+
top_k=50,
|
| 793 |
+
repetition_penalty=1.05,
|
| 794 |
+
seed=seed,
|
| 795 |
+
stop=[tokenizer.eos_token],
|
| 796 |
+
include_stop_str_in_output=True,
|
| 797 |
+
)
|
| 798 |
+
|
| 799 |
+
training_args = GRPOConfig(
|
| 800 |
+
output_dir=f"{output_dir}/checkpoints",
|
| 801 |
+
vllm_sampling_params=vllm_sampling_params,
|
| 802 |
+
temperature=temperature,
|
| 803 |
+
learning_rate=learning_rate,
|
| 804 |
+
weight_decay=0.001,
|
| 805 |
+
warmup_ratio=0.1,
|
| 806 |
+
lr_scheduler_type="linear",
|
| 807 |
+
optim="adamw_8bit",
|
| 808 |
+
logging_steps=1,
|
| 809 |
+
per_device_train_batch_size=1,
|
| 810 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 811 |
+
num_generations=num_generations,
|
| 812 |
+
max_prompt_length=max_prompt_length,
|
| 813 |
+
max_completion_length=max_completion_length,
|
| 814 |
+
max_steps=max_steps,
|
| 815 |
+
max_grad_norm=1.0,
|
| 816 |
+
save_steps=max(50, max_steps // 2),
|
| 817 |
+
report_to="none",
|
| 818 |
+
seed=seed,
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
logger.info("\n" + "=" * 60)
|
| 822 |
+
logger.info("GRPO CONFIGURATION")
|
| 823 |
+
logger.info("=" * 60)
|
| 824 |
+
logger.info(f"SFT Model: {sft_model_path}")
|
| 825 |
+
logger.info(f"Max steps: {max_steps}")
|
| 826 |
+
logger.info(f"Num generations: {num_generations}")
|
| 827 |
+
logger.info(f"Learning rate: {learning_rate}")
|
| 828 |
+
logger.info(f"Temperature: {temperature}")
|
| 829 |
+
logger.info(f"Max prompt length: {max_prompt_length}")
|
| 830 |
+
logger.info(f"Max completion length: {max_completion_length}")
|
| 831 |
+
logger.info(f"Reward functions: boxed_reward, think_tags_reward, accuracy_reward")
|
| 832 |
+
|
| 833 |
+
# =================================
|
| 834 |
+
# Create trainer (exactly 3 rewards)
|
| 835 |
+
# =================================
|
| 836 |
+
# Compute per-sample weights for type-balanced sampling
|
| 837 |
+
sample_weights = compute_sample_weights(all_samples)
|
| 838 |
+
|
| 839 |
+
# Note: TypeBalancedGRPOTrainer with WeightedRandomSampler was removed because
|
| 840 |
+
# GRPO's dataloader has special requirements for num_generations grouping.
|
| 841 |
+
# Overriding get_train_dataloader breaks the reward reshaping.
|
| 842 |
+
# Type balancing for GRPO is handled via the dataset composition instead.
|
| 843 |
+
|
| 844 |
+
trainer = GRPOTrainer(
|
| 845 |
+
model=model,
|
| 846 |
+
processing_class=tokenizer,
|
| 847 |
+
reward_funcs=[
|
| 848 |
+
boxed_reward,
|
| 849 |
+
think_tags_reward,
|
| 850 |
+
accuracy_reward,
|
| 851 |
+
],
|
| 852 |
+
args=training_args,
|
| 853 |
+
train_dataset=dataset,
|
| 854 |
+
)
|
| 855 |
+
|
| 856 |
+
logger.info("\n" + "=" * 60)
|
| 857 |
+
logger.info("STARTING GRPO TRAINING")
|
| 858 |
+
logger.info("=" * 60)
|
| 859 |
+
|
| 860 |
+
trainer.train()
|
| 861 |
+
|
| 862 |
+
# =================================
|
| 863 |
+
# Save model
|
| 864 |
+
# =================================
|
| 865 |
+
logger.info("\n" + "=" * 60)
|
| 866 |
+
logger.info("SAVING MODEL")
|
| 867 |
+
logger.info("=" * 60)
|
| 868 |
+
|
| 869 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 870 |
+
|
| 871 |
+
lora_output_dir = f"{output_dir}/lora"
|
| 872 |
+
model.save_pretrained(lora_output_dir)
|
| 873 |
+
tokenizer.save_pretrained(lora_output_dir)
|
| 874 |
+
logger.info(f"LoRA adapter saved to: {lora_output_dir}")
|
| 875 |
+
|
| 876 |
+
config = {
|
| 877 |
+
'base_model': base_model,
|
| 878 |
+
'sft_model': sft_model_path,
|
| 879 |
+
'lora_rank': lora_rank,
|
| 880 |
+
'target_modules': sft_target_modules,
|
| 881 |
+
'lora_alpha': sft_lora_alpha,
|
| 882 |
+
'max_seq_length': max_seq_length,
|
| 883 |
+
'max_steps': max_steps,
|
| 884 |
+
'num_generations': num_generations,
|
| 885 |
+
'learning_rate': learning_rate,
|
| 886 |
+
'temperature': temperature,
|
| 887 |
+
'min_agreement': min_agreement,
|
| 888 |
+
'train_samples': len(train_traces),
|
| 889 |
+
'augmentation_samples': len(augmentation_samples),
|
| 890 |
+
'total_samples': len(dataset),
|
| 891 |
+
'reward_functions': ['boxed_reward', 'think_tags_reward', 'accuracy_reward'],
|
| 892 |
+
}
|
| 893 |
+
|
| 894 |
+
with open(f"{output_dir}/grpo_config.json", 'w') as f:
|
| 895 |
+
json.dump(config, f, indent=2)
|
| 896 |
+
|
| 897 |
+
# =================================
|
| 898 |
+
# Merge to 16-bit (optional)
|
| 899 |
+
# =================================
|
| 900 |
+
merged_output_dir = None
|
| 901 |
+
if merge_16bit:
|
| 902 |
+
logger.info("\nMerging LoRA to 16-bit model...")
|
| 903 |
+
merged_output_dir = f"{output_dir}/merged_16bit"
|
| 904 |
+
|
| 905 |
+
from unsloth import FastLanguageModel as FLM
|
| 906 |
+
|
| 907 |
+
merge_model, merge_tokenizer = FLM.from_pretrained(
|
| 908 |
+
model_name=base_model,
|
| 909 |
+
max_seq_length=max_seq_length,
|
| 910 |
+
load_in_4bit=True,
|
| 911 |
+
fast_inference=False,
|
| 912 |
+
)
|
| 913 |
+
|
| 914 |
+
merge_model = FLM.get_peft_model(
|
| 915 |
+
merge_model,
|
| 916 |
+
r=lora_rank,
|
| 917 |
+
target_modules=sft_target_modules,
|
| 918 |
+
lora_alpha=sft_lora_alpha,
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
from safetensors.torch import load_file as load_safetensors
|
| 922 |
+
lora_weights_path = f"{lora_output_dir}/adapter_model.safetensors"
|
| 923 |
+
if os.path.exists(lora_weights_path):
|
| 924 |
+
state_dict = load_safetensors(lora_weights_path)
|
| 925 |
+
merge_model.load_state_dict(state_dict, strict=False)
|
| 926 |
+
logger.info(f"Loaded LoRA weights from {lora_weights_path}")
|
| 927 |
+
|
| 928 |
+
merge_model.save_pretrained_merged(
|
| 929 |
+
merged_output_dir,
|
| 930 |
+
merge_tokenizer,
|
| 931 |
+
save_method="merged_16bit",
|
| 932 |
+
)
|
| 933 |
+
logger.info(f"Merged model saved to: {merged_output_dir}")
|
| 934 |
+
|
| 935 |
+
# =================================
|
| 936 |
+
# Push to HuggingFace (optional)
|
| 937 |
+
# =================================
|
| 938 |
+
if push_to_hub and hf_repo:
|
| 939 |
+
logger.info(f"\nPushing to HuggingFace: {hf_repo}")
|
| 940 |
+
|
| 941 |
+
if hf_token:
|
| 942 |
+
from huggingface_hub import login
|
| 943 |
+
login(token=hf_token)
|
| 944 |
+
|
| 945 |
+
from huggingface_hub import HfApi
|
| 946 |
+
api = HfApi()
|
| 947 |
+
|
| 948 |
+
if merge_16bit and merged_output_dir:
|
| 949 |
+
logger.info("Pushing merged 16-bit model...")
|
| 950 |
+
api.create_repo(repo_id=hf_repo, exist_ok=True)
|
| 951 |
+
api.upload_folder(
|
| 952 |
+
folder_path=merged_output_dir,
|
| 953 |
+
repo_id=hf_repo,
|
| 954 |
+
repo_type="model",
|
| 955 |
+
commit_message="Upload V19 GRPO-trained merged model",
|
| 956 |
+
)
|
| 957 |
+
logger.info(f"Merged model pushed to: https://huggingface.co/{hf_repo}")
|
| 958 |
+
else:
|
| 959 |
+
lora_repo = f"{hf_repo}-lora" if not hf_repo.endswith("-lora") else hf_repo
|
| 960 |
+
logger.info(f"Pushing LoRA adapter to: {lora_repo}")
|
| 961 |
+
api.create_repo(repo_id=lora_repo, exist_ok=True)
|
| 962 |
+
api.upload_folder(
|
| 963 |
+
folder_path=lora_output_dir,
|
| 964 |
+
repo_id=lora_repo,
|
| 965 |
+
repo_type="model",
|
| 966 |
+
commit_message="Upload V19 GRPO-trained LoRA adapter",
|
| 967 |
+
)
|
| 968 |
+
logger.info(f"LoRA adapter pushed to: https://huggingface.co/{lora_repo}")
|
| 969 |
+
|
| 970 |
+
logger.info("\n" + "=" * 60)
|
| 971 |
+
logger.info("V19 GRPO TRAINING COMPLETE!")
|
| 972 |
+
logger.info("=" * 60)
|
| 973 |
+
logger.info(f"LoRA adapter: {lora_output_dir}")
|
| 974 |
+
if merged_output_dir:
|
| 975 |
+
logger.info(f"Merged model: {merged_output_dir}")
|
| 976 |
+
if push_to_hub and hf_repo:
|
| 977 |
+
logger.info(f"HuggingFace: https://huggingface.co/{hf_repo}")
|
| 978 |
+
|
| 979 |
+
|
| 980 |
+
# =============================================================================
|
| 981 |
+
# MAIN
|
| 982 |
+
# =============================================================================
|
| 983 |
+
|
| 984 |
+
def main():
|
| 985 |
+
parser = argparse.ArgumentParser(
|
| 986 |
+
description="Train Qwen3-32B with V19 GRPO (3 rewards: boxed, think tags, accuracy)",
|
| 987 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
# Model paths
|
| 991 |
+
parser.add_argument(
|
| 992 |
+
'--base-model', type=str,
|
| 993 |
+
default='unsloth/Qwen3-32B-bnb-4bit',
|
| 994 |
+
help='Base model (from HuggingFace)',
|
| 995 |
+
)
|
| 996 |
+
parser.add_argument(
|
| 997 |
+
'--sft-model', type=str, required=True,
|
| 998 |
+
help='Path or HF repo for V19 SFT LoRA adapter',
|
| 999 |
+
)
|
| 1000 |
+
parser.add_argument(
|
| 1001 |
+
'--output-dir', type=str,
|
| 1002 |
+
default='./outputs/qwen3-32b-v19-grpo',
|
| 1003 |
+
help='Output directory for GRPO model',
|
| 1004 |
+
)
|
| 1005 |
+
|
| 1006 |
+
# HuggingFace Hub
|
| 1007 |
+
parser.add_argument(
|
| 1008 |
+
'--hf-repo', type=str, default=None,
|
| 1009 |
+
help='HuggingFace repo to push model',
|
| 1010 |
+
)
|
| 1011 |
+
parser.add_argument(
|
| 1012 |
+
'--hf-token', type=str, default=None,
|
| 1013 |
+
help='HuggingFace token for pushing model',
|
| 1014 |
+
)
|
| 1015 |
+
parser.add_argument(
|
| 1016 |
+
'--push-to-hub', action='store_true',
|
| 1017 |
+
help='Push model to HuggingFace Hub after training',
|
| 1018 |
+
)
|
| 1019 |
+
parser.add_argument(
|
| 1020 |
+
'--merge-16bit', action='store_true',
|
| 1021 |
+
help='Merge LoRA into 16-bit model before pushing',
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
# Data paths
|
| 1025 |
+
parser.add_argument(
|
| 1026 |
+
'--train-checkpoint', type=str,
|
| 1027 |
+
default='./outputs/traces_final/traces_final.json',
|
| 1028 |
+
help='Path to training traces JSON',
|
| 1029 |
+
)
|
| 1030 |
+
parser.add_argument(
|
| 1031 |
+
'--test-checkpoint', type=str,
|
| 1032 |
+
default='',
|
| 1033 |
+
help='(unused) Path to test predictions for augmentation',
|
| 1034 |
+
)
|
| 1035 |
+
parser.add_argument(
|
| 1036 |
+
'--test-csv', type=str,
|
| 1037 |
+
default='',
|
| 1038 |
+
help='(unused) Path to test CSV for full questions',
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
# Model config
|
| 1042 |
+
parser.add_argument(
|
| 1043 |
+
'--max-seq-length', type=int, default=8192,
|
| 1044 |
+
help='Maximum sequence length',
|
| 1045 |
+
)
|
| 1046 |
+
parser.add_argument(
|
| 1047 |
+
'--lora-rank', type=int, default=None,
|
| 1048 |
+
help='LoRA rank (default: read from SFT adapter config)',
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
# Training config
|
| 1052 |
+
parser.add_argument(
|
| 1053 |
+
'--max-steps', type=int, default=100,
|
| 1054 |
+
help='Maximum training steps',
|
| 1055 |
+
)
|
| 1056 |
+
parser.add_argument(
|
| 1057 |
+
'--num-generations', type=int, default=6,
|
| 1058 |
+
help='Number of completions per prompt',
|
| 1059 |
+
)
|
| 1060 |
+
parser.add_argument(
|
| 1061 |
+
'--learning-rate', type=float, default=5e-6,
|
| 1062 |
+
help='Learning rate',
|
| 1063 |
+
)
|
| 1064 |
+
parser.add_argument(
|
| 1065 |
+
'--temperature', type=float, default=1.0,
|
| 1066 |
+
help='Sampling temperature for generation',
|
| 1067 |
+
)
|
| 1068 |
+
parser.add_argument(
|
| 1069 |
+
'--gradient-accumulation-steps', type=int, default=4,
|
| 1070 |
+
help='Gradient accumulation steps',
|
| 1071 |
+
)
|
| 1072 |
+
parser.add_argument(
|
| 1073 |
+
'--gpu-memory-utilization', type=float, default=0.95,
|
| 1074 |
+
help='GPU memory utilization for vLLM',
|
| 1075 |
+
)
|
| 1076 |
+
|
| 1077 |
+
# Data config
|
| 1078 |
+
parser.add_argument(
|
| 1079 |
+
'--min-agreement', type=int, default=3, choices=[3, 4],
|
| 1080 |
+
help='Minimum agreement for augmentation samples',
|
| 1081 |
+
)
|
| 1082 |
+
parser.add_argument(
|
| 1083 |
+
'--no-augment', action='store_true',
|
| 1084 |
+
help='Disable test set augmentation',
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
# Utility
|
| 1088 |
+
parser.add_argument(
|
| 1089 |
+
'--no-fast-inference', action='store_true',
|
| 1090 |
+
help='Disable vLLM fast inference',
|
| 1091 |
+
)
|
| 1092 |
+
parser.add_argument(
|
| 1093 |
+
'--dry-run', action='store_true',
|
| 1094 |
+
help='Validate data without training',
|
| 1095 |
+
)
|
| 1096 |
+
parser.add_argument(
|
| 1097 |
+
'--seed', type=int, default=42,
|
| 1098 |
+
help='Random seed',
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
args = parser.parse_args()
|
| 1102 |
+
|
| 1103 |
+
train(
|
| 1104 |
+
sft_model_path=args.sft_model,
|
| 1105 |
+
base_model=args.base_model,
|
| 1106 |
+
train_checkpoint_path=args.train_checkpoint,
|
| 1107 |
+
test_checkpoint_path=args.test_checkpoint,
|
| 1108 |
+
test_csv_path=args.test_csv,
|
| 1109 |
+
output_dir=args.output_dir,
|
| 1110 |
+
hf_repo=args.hf_repo,
|
| 1111 |
+
hf_token=args.hf_token,
|
| 1112 |
+
max_seq_length=args.max_seq_length,
|
| 1113 |
+
lora_rank=args.lora_rank,
|
| 1114 |
+
max_steps=args.max_steps,
|
| 1115 |
+
num_generations=args.num_generations,
|
| 1116 |
+
learning_rate=args.learning_rate,
|
| 1117 |
+
temperature=args.temperature,
|
| 1118 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 1119 |
+
gpu_memory_utilization=args.gpu_memory_utilization,
|
| 1120 |
+
min_agreement=args.min_agreement,
|
| 1121 |
+
use_augmentation=not args.no_augment,
|
| 1122 |
+
push_to_hub=args.push_to_hub,
|
| 1123 |
+
merge_16bit=args.merge_16bit,
|
| 1124 |
+
fast_inference=not args.no_fast_inference,
|
| 1125 |
+
dry_run=args.dry_run,
|
| 1126 |
+
seed=args.seed,
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
if __name__ == "__main__":
|
| 1131 |
+
main()
|