# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "bitsandbytes"]
# ///
"""Fine-tune Qwen3-0.6B on WirelessMATHBench-XL with LoRA-based SFT.

The script runs three steps in order:

1. A teacher model generates a step-by-step reasoning trace for every
   prompt in the train and test splits; the dataset's reference answer is
   appended whenever the teacher's output does not already contain it.
2. Each example is converted to a two-turn ``messages`` conversation
   (user prompt, assistant reasoning + answer).
3. ``SFTTrainer`` trains LoRA adapters on the conversations and pushes the
   result to the Hugging Face Hub.
"""

import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import trackio

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

BANNER = "=" * 60

# Delimiters around the reasoning section of an assistant reply.
# NOTE(review): the previous revision had empty strings where these tags
# belong (so the "tag missing" check could never fire). `<think>` /
# `</think>` is assumed from the Qwen3 reasoning format -- confirm.
THINK_OPEN = "<think>"
THINK_CLOSE = "</think>"

print(BANNER)
print("Fine-tuning Qwen3-0.6B on WirelessMATHBench-XL")
print("Method: SFT with LoRA + Reasoning Generation")
print("Dataset: Wireless Communications Math")
print("Fix: Preserves capability")
print(BANNER)

# Load WirelessMATHBench-XL dataset
print("\nLoading WirelessMATHBench-XL dataset...")
train_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='train')
eval_dataset = load_dataset('XINLI1997/WirelessMATHBench-XL', split='test')

print(f"Train examples: {len(train_dataset)}")
print(f"Eval examples: {len(eval_dataset)}")

# Load Teacher Model for Reasoning Generation (Preprocessing Step)
TEACHER_MODEL = "Qwen/Qwen2.5-3B-Instruct"

print(f"\n{BANNER}")
print("STEP 1: Generating Reasoning Steps (Preserves )")
print(f"Teacher Model: {TEACHER_MODEL}")
print(f"{BANNER}")

teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL, trust_remote_code=True)
# Batched generation with a decoder-only model needs LEFT padding: with
# right padding the model would continue from pad tokens for every prompt
# shorter than the longest one in the batch.
teacher_tokenizer.padding_side = "left"
if teacher_tokenizer.pad_token is None:
    # padding=True requires a pad token; fall back to EOS when none is set.
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()

print("✓ Teacher model loaded for reasoning generation\n")


def _build_teacher_prompt(prompt):
    """Wrap one dataset prompt in the ChatML turn sent to the teacher model.

    The instruction names the reasoning tags explicitly so that the
    teacher's output matches what `_ensure_reasoning_format` checks for.
    """
    # NOTE(review): line breaks inside this template were reconstructed
    # following the ChatML layout -- confirm against the intended prompt.
    return (
        "<|im_start|>user\n"
        f"{prompt}\n\n"
        "Solve step-by-step. "
        f"Put reasoning in {THINK_OPEN}{THINK_CLOSE} tags, then give final answer.<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


def _ensure_reasoning_format(response, answer):
    """Return `response` shaped as ``<think>reasoning</think>\\n\\nanswer``.

    Args:
        response: Decoded teacher output for one example.
        answer: Reference answer from the dataset; coerced to ``str`` so
            non-string answers do not break the containment check.

    Returns:
        The stripped response, with missing reasoning tags added and the
        reference answer appended when it does not already occur in it.
    """
    answer = str(answer)
    response = response.strip()

    if THINK_CLOSE not in response:
        # Reasoning was never closed (untagged output, or generation hit
        # max_new_tokens mid-reasoning): close it and append the answer.
        if THINK_OPEN not in response:
            response = f"{THINK_OPEN}\n{response}"
        return f"{response}\n{THINK_CLOSE}\n\n{answer}"

    if answer not in response:
        # Well-formed reasoning, but the reference answer is absent.
        return f"{response}\n\n{answer}"

    return response


def generate_reasoning_batch(examples):
    """Generate reasoning steps using teacher model (batch processing).

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch dict providing the ``'prompt'`` and
            ``'correct_answer'`` columns as parallel lists.

    Returns:
        ``{"reasoning_answer": [...]}`` with one formatted reasoning +
        answer string per input example.
    """
    prompts = examples['prompt']
    answers = examples['correct_answer']

    # Create reasoning prompts
    reasoning_prompts = [_build_teacher_prompt(prompt) for prompt in prompts]

    # Generate with teacher
    # NOTE(review): prompts longer than 512 tokens are truncated, which can
    # cut off the trailing assistant header -- confirm prompts fit.
    inputs = teacher_tokenizer(
        reasoning_prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(teacher_model.device)

    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=False,
            pad_token_id=teacher_tokenizer.pad_token_id,
        )

    # Every row is padded to the same prompt length, so the newly generated
    # tokens start at the same offset for the whole batch.
    prompt_length = inputs['input_ids'].shape[1]

    # Process responses
    responses_with_reasoning = []
    for output, answer in zip(outputs, answers):
        generated_ids = output[prompt_length:]
        # Skip special tokens so end-of-turn and padding markers emitted
        # after the reply do not end up inside the SFT training targets.
        response = teacher_tokenizer.decode(generated_ids, skip_special_tokens=True)
        responses_with_reasoning.append(_ensure_reasoning_format(response, answer))

    return {"reasoning_answer": responses_with_reasoning}


print("Generating reasoning for training set (this may take time)...")
train_dataset = train_dataset.map(
    generate_reasoning_batch,
    batched=True,
    batch_size=4,
    desc="Generating reasoning"
)

print("Generating reasoning for eval set...")
eval_dataset = eval_dataset.map(
    generate_reasoning_batch,
    batched=True,
    batch_size=4,
    desc="Generating reasoning"
)

print("✓ Reasoning generation complete!\n")

# Clean up teacher model to free memory
del teacher_model
del teacher_tokenizer
torch.cuda.empty_cache()

print("✓ Teacher model unloaded\n")


def format_for_sft(example):
    """Format augmented data for SFT training.

    Args:
        example: One row with the ``'prompt'`` and ``'reasoning_answer'``
            columns.

    Returns:
        ``{'messages': [...]}`` holding a user turn (the prompt) followed
        by an assistant turn (the reasoning-augmented answer).
    """
    prompt = example['prompt']
    answer_with_reasoning = example['reasoning_answer']

    messages = [
        {'role': 'user', 'content': prompt},
        {'role': 'assistant', 'content': answer_with_reasoning}
    ]
    return {'messages': messages}


print(f"{BANNER}")
print("STEP 2: Formatting for SFT Training")
print(f"{BANNER}\n")

# Drop every original column so only 'messages' remains.
train_dataset = train_dataset.map(
    format_for_sft,
    remove_columns=train_dataset.column_names
)
eval_dataset = eval_dataset.map(
    format_for_sft,
    remove_columns=eval_dataset.column_names
)

print("✓ Dataset formatted with reasoning preserved")

# Configure LoRA for efficient fine-tuning
print("\nConfiguring LoRA...")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

# Configure SFT training
print("Configuring training arguments...")
training_args = SFTConfig(
    output_dir="qwen3-wireless-math",

    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 16

    # Optimization
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,

    # Evaluation and saving
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,

    # Logging and monitoring
    logging_steps=10,
    report_to="trackio",
    run_name="qwen3-0.6b-wireless-math-reasoning",
    project="wireless-math-finetuning",

    # Memory optimization
    gradient_checkpointing=False,  # Disabled to avoid gradient computation issues
    bf16=True,

    # Hub integration
    push_to_hub=True,
    hub_model_id="wlabchoi/qwen3-0.6b-wireless-math-reasoning",
    hub_strategy="every_save",
    hub_private_repo=False,

    # Performance
    dataloader_num_workers=0,  # Avoid multiprocessing issues
    remove_unused_columns=False,
)

# Initialize trainer
print("\nInitializing SFT Trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    args=training_args,
)

# Start training
print("\n" + BANNER)
print("STEP 3: SFT Training on Reasoning-Augmented Data")
print(BANNER)
print("Model: Qwen3-0.6B")
print("Dataset: WirelessMATHBench-XL (with generated reasoning)")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
print("Epochs: 3")
print("Result: Model preserves capability")
print(BANNER + "\n")

trainer.train()

# Push final model to Hub
print("\nPushing final model to Hub...")
trainer.push_to_hub(commit_message="SFT complete - Qwen3-0.6B on WirelessMATH with reasoning preservation")

print("\n" + BANNER)
print("✓ Fine-Tuning Complete - Reasoning Preserved!")
print(BANNER)
print("Model now:")
print(" ✓ Knows wireless communications mathematics")
print(" ✓ Maintains chain-of-thought")
print(" ✓ Shows reasoning steps before answers")
print(BANNER)