#!/usr/bin/env python3
"""Phase 2: SFT training on Qwen3-4B.

LoRA fine-tunes the base model on the dataset produced by phase 1, then
saves both the adapter and a merged full-weight checkpoint.
"""
import os
import time
from pathlib import Path

import torch
from datasets import load_from_disk
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
BASE_MODEL = "Qwen/Qwen3-4B"
DATA_DIR = Path("./qwen3_pipeline/data")
CKPT_DIR = Path("./qwen3_pipeline/checkpoint")
CKPT_DIR.mkdir(parents=True, exist_ok=True)

EPOCHS = 1
BATCH_SIZE = 2
GRAD_ACCUM = 8
LR = 2e-4
MAX_SEQ_LEN = 4096

LORA_RANK = 32
LORA_ALPHA = 64
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"]

print("="*70)
print("PHASE 2: SFT TRAINING")
print("="*70)

# ---------------------------------------------------------------------------
# [1/4] Load model and tokenizer
# ---------------------------------------------------------------------------
print(f"\n[1/4] Loading {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Reuse EOS as pad; labels for pad positions are masked below so this
    # does not teach the model to emit EOS over the padded tail.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # right-pad so real tokens stay left-aligned

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager",
)
print(f" Model loaded")
# Guard the diagnostic so a CPU-only run does not crash here.
if torch.cuda.is_available():
    print(f" GPU memory: {torch.cuda.memory_allocated()/1e9:.1f} GB")

# ---------------------------------------------------------------------------
# [2/4] Apply LoRA
# ---------------------------------------------------------------------------
print(f"\n[2/4] Applying LoRA...")
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    target_modules=LORA_TARGETS,
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    init_lora_weights="gaussian",
    use_rslora=True,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Required so gradient checkpointing works with frozen base weights:
# inputs must carry requires_grad for the checkpointed segments.
model.enable_input_require_grads()

# ---------------------------------------------------------------------------
# [3/4] Load and tokenize data
# ---------------------------------------------------------------------------
print(f"\n[3/4] Loading and tokenizing data...")
dataset = load_from_disk(str(DATA_DIR / "sft"))
print(f" Dataset: {len(dataset)} samples")


def tokenize_function(examples):
    """Render chat messages to text and tokenize for causal-LM SFT.

    Args:
        examples: batched dataset slice with a ``messages`` column
            (list of chat-format message lists).

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``;
        label positions corresponding to padding are set to -100 so the
        cross-entropy loss ignores them.
    """
    texts = []
    for msg in examples["messages"]:
        text = tokenizer.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=False
        )
        # NOTE(review): Qwen chat templates typically already terminate
        # with <|im_end|> (the EOS token), so this append may duplicate
        # it — confirm against the actual template.
        texts.append(text + tokenizer.eos_token)

    result = tokenizer(
        texts,
        truncation=True,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        return_tensors=None,
    )

    # Build labels fresh (no aliasing of the input_ids inner lists) and
    # mask pad positions with -100. Without this, and with pad == eos,
    # the model would be trained to emit endless EOS over the padding.
    result["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(result["input_ids"], result["attention_mask"])
    ]
    return result


print(" Tokenizing...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
    desc="Tokenizing",
    num_proc=4,
)
print(f" Tokenized: {len(tokenized_dataset)} samples")

# ---------------------------------------------------------------------------
# [4/4] Train
# ---------------------------------------------------------------------------
print(f"\n[4/4] Training...")
steps_per_epoch = len(tokenized_dataset) // (BATCH_SIZE * GRAD_ACCUM)
total_steps = steps_per_epoch * EPOCHS
print(f" Batch size: {BATCH_SIZE}")
print(f" Grad accum: {GRAD_ACCUM}")
print(f" Effective batch: {BATCH_SIZE * GRAD_ACCUM}")
print(f" Steps per epoch: {steps_per_epoch}")
print(f" Total steps: {total_steps}")
print(f" Learning rate: {LR}")
print(f" Estimated time: ~30-40 min")

training_args = TrainingArguments(
    output_dir=str(CKPT_DIR),
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    bf16=True,
    logging_steps=10,
    save_strategy="no",      # final adapter/merged checkpoints saved manually below
    optim="adamw_torch",
    gradient_checkpointing=True,
    seed=42,
    report_to="none",
    dataloader_num_workers=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

print(f"\n{'='*70}")
print("TRAINING STARTED")
print(f"{'='*70}\n")

start = time.time()
trainer.train()
elapsed = (time.time() - start) / 60

print(f"\n{'='*70}")
print(f"✓ TRAINING COMPLETE: {elapsed:.1f} minutes")
print(f"{'='*70}")

# Save the LoRA adapter (small; keeps a re-mergeable artifact).
print(f"\nSaving model...")
adapter_path = CKPT_DIR / "adapter"
model.save_pretrained(str(adapter_path))
tokenizer.save_pretrained(str(adapter_path))
print(f" ✓ Adapter: {adapter_path}")

# Bake the LoRA deltas into the base weights for standalone inference.
print(f"\nMerging LoRA weights...")
model = model.merge_and_unload()
merged_path = CKPT_DIR / "merged"
# Persist the merged (LoRA-baked) model together with its tokenizer.
for artifact in (model, tokenizer):
    artifact.save_pretrained(str(merged_path))
print(f" ✓ Merged: {merged_path}")

# Release the big objects and return cached GPU memory before exiting.
del model, trainer
torch.cuda.empty_cache()

banner = "=" * 70
print(f"\n{banner}")
print("✓ PHASE 2 COMPLETE")
print(f"{banner}")
print(f"\nTime: {elapsed:.1f} minutes")
print(f"Cost: ~${elapsed/60 * 1.15:.2f}")
print("\n➡️ Next: python phase3_eval.py")