# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "torch"]
# ///
"""
SFT Training Script for Qwen/Qwen2.5-0.5B
Fine-tunes a small Qwen model using Supervised Fine-Tuning with LoRA
"""
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio  # imported up front so a missing trackio install fails fast (used via report_to below)
print("πŸš€ Starting SFT training for Qwen/Qwen2.5-0.5B")
print("=" * 60)
# Load dataset - using TRL's Capybara dataset (known compatible format)
print("\nπŸ“¦ Loading dataset: trl-lib/Capybara")
dataset = load_dataset("trl-lib/Capybara", split="train")
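# Capybara rows use TRL's conversational format: a "messages" list of
# {"role": ..., "content": ...} dicts. SFTTrainer detects this format and
# applies the tokenizer's chat template, so no manual preprocessing is needed.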
# For demo purposes, take a small subset for quick training
print("βœ‚οΈ Taking subset of 500 examples for quick demo training")
dataset = dataset.select(range(500))
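# Note: select(range(500)) keeps the first 500 rows, not a random sample;
# call dataset.shuffle(seed=42) first if a random subset is preferred.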
# Create train/eval split for monitoring
print("πŸ”€ Creating train/test split (90/10)")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
print(f" Train examples: {len(dataset_split['train'])}")
print(f" Eval examples: {len(dataset_split['test'])}")
# Configure LoRA for parameter-efficient fine-tuning
print("\nβš™οΈ Configuring LoRA (r=16, alpha=32)")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
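# Back-of-the-envelope adapter size (assuming Qwen2.5-0.5B's published config:
# hidden size 896, 24 layers, GQA with 128-dim k/v projections): LoRA adds
# r*(in+out) weights per targeted module, i.e. 16*(896+896)*2 + 16*(896+128)*2
# = ~90K per layer, ~2.2M in total, well under 1% of the ~494M base parameters.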
# Configure training arguments
print("\nπŸ“‹ Setting up training configuration")
training_args = SFTConfig(
    # Model output settings
    output_dir="qwen-0.5b-sft-demo",
    # Hub settings - CRITICAL for saving results
    push_to_hub=True,
    hub_model_id="qwen-0.5b-sft-capybara",  # Hub prefixes this with the username
    hub_strategy="end",  # Push the final model once training finishes
    # Training hyperparameters
    num_train_epochs=1,  # Just 1 epoch for a quick demo
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # Effective batch size = 2 * 4 = 8
    learning_rate=2e-4,
    warmup_steps=10,
    # Evaluation settings
    eval_strategy="steps",
    eval_steps=25,
    # Logging and monitoring
    logging_steps=10,
    report_to="trackio",
    run_name="qwen-0.5b-sft-demo",
    # Optimization settings
    gradient_checkpointing=True,
    optim="adamw_torch",
    # Misc settings
    save_strategy="epoch",
    bf16=True,  # bfloat16 for stability; requires hardware with bf16 support
    max_grad_norm=1.0,
)
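# Step-count sanity check (assuming a single GPU): 450 train examples with an
# effective batch size of 2 * 4 = 8 gives ~56 optimizer steps for the one
# epoch, so eval_steps=25 means roughly two mid-training evaluations.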
print(f" Model: Qwen/Qwen2.5-0.5B")
print(f" Epochs: {training_args.num_train_epochs}")
print(f" Batch size: {training_args.per_device_train_batch_size}")
print(f" Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f" Learning rate: {training_args.learning_rate}")
# Initialize trainer
print("\nπŸ‹οΈ Initializing SFT Trainer")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # TRL accepts a model id string and loads it internally
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
    peft_config=peft_config,
    args=training_args,
)
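# Sanity check (assumes trainer.model is the PEFT-wrapped model, which TRL
# produces when peft_config is passed): report how few parameters LoRA trains.
trainer.model.print_trainable_parameters()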
# Start training
print("\n" + "=" * 60)
print("🎯 Starting training...")
print("=" * 60)
trainer.train()
print("\nβœ… Training completed!")
# Push final model to Hub
print("\nπŸ“€ Pushing final model to Hub...")
trainer.push_to_hub()
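# To reuse the adapter later (hypothetical repo id - the Hub prefixes it with
# your username), loading via PEFT should look like:
#   from peft import AutoPeftModelForCausalLM
#   model = AutoPeftModelForCausalLM.from_pretrained("<username>/qwen-0.5b-sft-capybara")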
print("\n" + "=" * 60)
print("πŸŽ‰ Training job completed successfully!")
print("=" * 60)
print(f"πŸ“Š Model saved to: {training_args.hub_model_id}")
print("πŸ’‘ Check Trackio dashboard for detailed metrics")