|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
from peft import LoraConfig |
|
|
from trl import SFTTrainer, SFTConfig |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import trackio |
|
|
|
|
|
|
|
|
# Avoid tokenizer fork warnings/deadlocks when dataloader workers are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Startup banner describing the run.
_RULE = "=" * 60
print(_RULE)
print("Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL")
print("Method: SFT with LoRA + Reasoning Generation")
print("Dataset: Wireless Communications Math")
print("Fix: Preserves <think></think> capability")
print(_RULE)
|
|
|
|
|
|
|
|
# Download both benchmark splits from the Hugging Face Hub.
print("\nLoading WirelessMATHBench-XL dataset...")
_DATASET_ID = 'XINLI1997/WirelessMATHBench-XL'
train_dataset = load_dataset(_DATASET_ID, split='train')
eval_dataset = load_dataset(_DATASET_ID, split='test')

print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")
|
|
|
|
|
|
|
|
# Teacher model used only offline, to generate reasoning traces for SFT targets.
TEACHER_MODEL = "Qwen/Qwen2.5-3B-Instruct"

print(f"\n{'='*60}")
print(f"STEP 1: Generating Reasoning Steps (Preserves <think></think>)")
print(f"Teacher Model: {TEACHER_MODEL}")
print(f"{'='*60}")

teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, trust_remote_code=True)
# FIX: decoder-only models must be LEFT-padded for batched generation.
# With right padding the model continues from pad tokens, and the
# "slice off the prompt length" logic in generate_reasoning_batch would
# cut real prompt tokens instead of padding.
teacher_tokenizer.padding_side = "left"
if teacher_tokenizer.pad_token is None:
    # Fall back to EOS so batched padding always has a valid token id.
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()  # inference only; no gradients needed
print("β Teacher model loaded for reasoning generation\n")
|
|
|
|
|
def generate_reasoning_batch(examples):
    """Generate reasoning traces for a batch of problems via the teacher model.

    Each prompt is wrapped in ChatML and primed with an opening ``<think>``
    tag so the teacher starts its chain of thought immediately.  The stored
    training target is normalized so it always:

    * starts with ``<think>`` — the priming tag lives in the prompt, so the
      raw generation would otherwise LOSE the opening tag and the student
      would be trained on unbalanced think tags;
    * contains a closing ``</think>``; and
    * contains the ground-truth answer after the reasoning.

    Args:
        examples: `datasets` batch dict with ``prompt`` and
            ``correct_answer`` columns.

    Returns:
        Dict with a new ``reasoning_answer`` column (one target per prompt).
    """
    prompts = examples['prompt']
    answers = examples['correct_answer']

    # Build ChatML prompts; trailing "<think>" primes the chain of thought.
    reasoning_prompts = []
    for prompt in prompts:
        reasoning_prompt = f"""<|im_start|>user
{prompt}

Solve step-by-step. Put reasoning in <think></think> tags, then give final answer.<|im_end|>
<|im_start|>assistant
<think>"""
        reasoning_prompts.append(reasoning_prompt)

    inputs = teacher_tokenizer(
        reasoning_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(teacher_model.device)

    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,  # greedy decoding: deterministic targets
            pad_token_id=teacher_tokenizer.pad_token_id,
        )

    # Every row is padded to the same length, so one slice offset suffices.
    prompt_len = inputs['input_ids'].shape[1]

    responses_with_reasoning = []
    for i, output in enumerate(outputs):
        generated_ids = output[prompt_len:]
        # FIX: skip special tokens so <|im_end|> / padding tokens do not
        # leak into the SFT targets.  (<think>/</think> are plain text for
        # the Qwen2.5 tokenizer, so they survive the decode.)
        response = teacher_tokenizer.decode(generated_ids, skip_special_tokens=True)

        # FIX: restore the opening <think> consumed by the prompt priming.
        response = "<think>\n" + response.strip()
        answer = str(answers[i])

        if '</think>' not in response:
            # Teacher never closed its reasoning: close it and append the
            # gold answer so the target stays well-formed.
            response = response + f"\n</think>\n\n{answer}"
        elif answer not in response:
            # Reasoning closed but the final answer is missing: append it.
            response = response + f"\n\n{answer}"

        responses_with_reasoning.append(response)

    return {"reasoning_answer": responses_with_reasoning}
|
|
|
|
|
# Augment both splits with teacher-generated reasoning (batched for speed).
print("Generating reasoning for training set (this may take time)...")
train_dataset = train_dataset.map(
    generate_reasoning_batch, batched=True, batch_size=4, desc="Generating reasoning"
)

print("Generating reasoning for eval set...")
eval_dataset = eval_dataset.map(
    generate_reasoning_batch, batched=True, batch_size=4, desc="Generating reasoning"
)

print("β Reasoning generation complete!\n")
|
|
|
|
|
|
|
|
# The teacher is no longer needed; drop it and reclaim GPU memory before
# the student model is loaded for SFT.
del teacher_model, teacher_tokenizer
torch.cuda.empty_cache()
print("β Teacher model unloaded\n")
|
|
|
|
|
def format_for_sft(example):
    """Convert one augmented row into the chat `messages` format TRL expects.

    The user turn is the original problem prompt; the assistant turn is the
    teacher-generated answer including its reasoning trace.
    """
    conversation = [
        {'role': 'user', 'content': example['prompt']},
        {'role': 'assistant', 'content': example['reasoning_answer']},
    ]
    return {'messages': conversation}
|
|
|
|
|
# Replace the raw columns with TRL-ready `messages` on both splits.
print("=" * 60)
print("STEP 2: Formatting for SFT Training")
print("=" * 60 + "\n")

train_dataset = train_dataset.map(format_for_sft, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_for_sft, remove_columns=eval_dataset.column_names)

print("β Dataset formatted with reasoning preserved")
|
|
|
|
|
|
|
|
# LoRA adapter configuration: rank-16 adapters on the attention projections
# only, with alpha/r = 2.0 effective scaling; bias terms stay frozen.
print("\nConfiguring LoRA...")
lora_kwargs = dict(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
peft_config = LoraConfig(**lora_kwargs)
|
|
|
|
|
|
|
|
print("Configuring training arguments...")
# SFT hyperparameters; effective train batch size = 4 * 4 (accumulation) = 16.
training_args = SFTConfig(
    output_dir="qwen3-wireless-math",

    # --- schedule / batching ---
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,

    # --- optimizer ---
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # --- evaluation / checkpointing ---
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,  # keep only the 3 most recent checkpoints on disk

    # --- logging (trackio experiment tracker) ---
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-wireless-math-reasoning",
    project="wireless-math-finetuning",

    # --- precision / memory ---
    gradient_checkpointing=False,
    bf16=True,

    # --- Hugging Face Hub uploads ---
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-wireless-math-reasoning",
    hub_strategy="every_save",  # push a commit at every checkpoint save
    hub_private_repo=False,

    # --- dataloading ---
    dataloader_num_workers=0,
    remove_unused_columns=False,  # keep `messages` column for the SFT collator
)
|
|
|
|
|
|
|
|
print("\nInitializing SFT Trainer...")
# Passing a model id string lets SFTTrainer load the student (Qwen3-0.6B)
# itself; the LoRA adapters from `peft_config` are applied on top.
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    args=training_args,
)
|
|
|
|
|
|
|
|
# Training banner, then launch SFT.
print("\n" + "=" * 60)
print("STEP 3: SFT Training on Reasoning-Augmented Data")
print("=" * 60)
print("Model: Qwen3-0.6B")
print("Dataset: WirelessMATHBench-XL (with generated reasoning)")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
print("Epochs: 3")
print("Result: Model preserves <think></think> capability")
print("=" * 60 + "\n")

trainer.train()
|
|
|
|
|
|
|
|
# Upload the final model to the Hub, then print the closing summary.
print("\nPushing final model to Hub...")
trainer.push_to_hub(commit_message="SFT complete - Qwen3-0.6B on WirelessMATH with reasoning preservation")

print("\n" + "=" * 60)
print("β Fine-Tuning Complete - Reasoning Preserved!")
print("=" * 60)
print("Model now:")
print("  β Knows wireless communications mathematics")
print("  β Maintains <think></think> chain-of-thought")
print("  β Shows reasoning steps before answers")
print("=" * 60)
|
|
|