# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "bitsandbytes"]
# ///
import os
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio
# Silence the fork-related tokenizers warning before any dataloader work starts.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
_BANNER = "=" * 60
for _line in (
    _BANNER,
    "Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL",
    "Method: SFT with LoRA + Reasoning Generation",
    "Dataset: Wireless Communications Math",
    "Fix: Preserves capability",
    _BANNER,
):
    print(_line)
# Pull both benchmark splits from the Hugging Face Hub.
print("\nLoading WirelessMATHBench-XL dataset...")
train_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='train')
eval_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='test')
print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
# Load Teacher Model for Reasoning Generation (Preprocessing Step)
TEACHER_MODEL = "Qwen/Qwen2.5-3B-Instruct"
print(f"\n{'='*60}")
print(f"STEP 1: Generating Reasoning Steps (Preserves )")
print(f"Teacher Model: {TEACHER_MODEL}")
print(f"{'='*60}")
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, trust_remote_code=True)
# FIX: decoder-only models must be LEFT-padded for batched generation.
# With the default right padding, shorter prompts in a batch end in pad
# tokens and the model generates a continuation of padding, so the
# teacher's reasoning for those rows would be garbage.
teacher_tokenizer.padding_side = "left"
if teacher_tokenizer.pad_token is None:
    # Fall back to EOS so `padding=True` in the tokenizer call cannot fail.
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,   # halves memory vs fp32; fine for inference
    device_map="auto",            # shard/place on available accelerators
    trust_remote_code=True,
)
teacher_model.eval()  # inference only: disables dropout
print("✓ Teacher model loaded for reasoning generation\n")
def generate_reasoning_batch(examples):
    """Generate reasoning-augmented targets for a batch of examples.

    Wraps each problem in the teacher's chat template, greedily decodes a
    step-by-step solution, and guarantees the ground-truth answer appears
    in every target string.

    Args:
        examples: batched dataset columns; reads ``prompt`` and
            ``correct_answer`` (parallel lists of equal length).

    Returns:
        dict with one new column, ``reasoning_answer``: the teacher's
        solution text, with the gold answer appended when it is missing.
    """
    prompts = examples['prompt']
    answers = examples['correct_answer']
    # Build the Qwen ChatML prompt by hand for each problem.
    reasoning_prompts = []
    for prompt in prompts:
        reasoning_prompt = f"""<|im_start|>user
{prompt}
Solve step-by-step. Put reasoning in tags, then give final answer.<|im_end|>
<|im_start|>assistant
"""
        reasoning_prompts.append(reasoning_prompt)
    # One padded batch; prompts capped at 512 tokens.
    inputs = teacher_tokenizer(
        reasoning_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(teacher_model.device)
    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,  # greedy decoding: deterministic teacher targets
            pad_token_id=teacher_tokenizer.pad_token_id,
        )
    responses_with_reasoning = []
    # Padded prompt length is identical for every row in the batch.
    prompt_len = inputs['input_ids'].shape[1]
    for i, output in enumerate(outputs):
        generated_ids = output[prompt_len:]
        # FIX: skip special tokens so pad / <|im_end|> markers do not leak
        # into the SFT targets (the chat template re-adds them at training).
        response = teacher_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        # FIX: the original guard `if '' not in response:` was always False
        # (every string contains the empty string), so that branch was dead
        # code. Keep only the meaningful invariant: the gold answer must
        # appear in the training target.
        if answers[i] not in response:
            response = f"{response}\n\n{answers[i]}"
        responses_with_reasoning.append(response)
    return {"reasoning_answer": responses_with_reasoning}
# Augment both splits with teacher-generated reasoning (batch of 4 prompts
# per generate call to bound GPU memory).
print("Generating reasoning for training set (this may take time)...")
train_dataset = train_dataset.map(
    generate_reasoning_batch, batched=True, batch_size=4, desc="Generating reasoning",
)
print("Generating reasoning for eval set...")
eval_dataset = eval_dataset.map(
    generate_reasoning_batch, batched=True, batch_size=4, desc="Generating reasoning",
)
print("✓ Reasoning generation complete!\n")
# The teacher is only needed for preprocessing; release it so the student
# training run has the GPU memory to itself.
del teacher_model, teacher_tokenizer
torch.cuda.empty_cache()
print("✓ Teacher model unloaded\n")
def format_for_sft(example):
    """Convert one reasoning-augmented row into chat ``messages`` for SFT.

    The user turn carries the original problem; the assistant turn carries
    the teacher's reasoning-plus-answer text.
    """
    return {
        'messages': [
            {'role': 'user', 'content': example['prompt']},
            {'role': 'assistant', 'content': example['reasoning_answer']},
        ]
    }
step2_rule = "=" * 60
print(step2_rule)
print("STEP 2: Formatting for SFT Training")
print(step2_rule + "\n")
# Replace every original column with the single `messages` column that the
# SFT trainer's chat-template collator expects.
train_dataset = train_dataset.map(format_for_sft, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_for_sft, remove_columns=eval_dataset.column_names)
print("✓ Dataset formatted with reasoning preserved")
# LoRA adapters on the attention projections keep the trainable parameter
# count tiny relative to full fine-tuning.
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,               # adapter rank
    lora_alpha=32,      # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
print("Configuring training arguments...")
training_args = SFTConfig(
    output_dir="qwen3-wireless-math",
    # --- schedule ---
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch = 4 * 4 = 16
    # --- optimizer ---
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    # --- evaluation / checkpoints ---
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    # --- logging / tracking ---
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-wireless-math-reasoning",
    project="wireless-math-finetuning",
    # --- memory / precision ---
    gradient_checkpointing=False,  # off: avoids gradient-computation issues
    bf16=True,
    # --- Hub integration ---
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-wireless-math-reasoning",
    hub_strategy="every_save",
    hub_private_repo=False,
    # --- dataloader ---
    dataloader_num_workers=0,   # single process: avoids multiprocessing issues
    remove_unused_columns=False,
)
# Build the trainer; TRL resolves the student model from its Hub id and
# wraps it with the LoRA adapters defined above.
print("\nInitializing SFT Trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    args=training_args,
    peft_config=peft_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
step3_rule = "=" * 60
print("\n" + step3_rule)
print("STEP 3: SFT Training on Reasoning-Augmented Data")
print(step3_rule)
print("Model: Qwen3-0.6B")
print("Dataset: WirelessMATHBench-XL (with generated reasoning)")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
print("Epochs: 3")
print("Result: Model preserves capability")
print(step3_rule + "\n")
trainer.train()
# Upload the final adapter and training card to the Hub.
print("\nPushing final model to Hub...")
trainer.push_to_hub(commit_message="SFT complete - Qwen3-0.6B on WirelessMATH with reasoning preservation")
print("\n" + step3_rule)
print("✓ Fine-Tuning Complete - Reasoning Preserved!")
print(step3_rule)
for summary_line in (
    "Model now:",
    " ✓ Knows wireless communications mathematics",
    " ✓ Maintains chain-of-thought",
    " ✓ Shows reasoning steps before answers",
):
    print(summary_line)
print(step3_rule)