# /// script
# dependencies = [
# "trl>=0.12.0",
# "peft>=0.7.0",
# "trackio",
# "torch",
# "transformers>=4.44.0",
# "datasets",
# "accelerate",
# ]
# ///
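# The header above is PEP 723 inline script metadata, so the file can be run
# directly with a compatible tool, e.g. `uv run <this-file>.py`, which resolves
# the listed dependencies into an ephemeral environment. (Launch method is an
# assumption; the source does not specify one.)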
"""
LoRA Fine-tuning for Qwen2.5-72B on Consumer Genomics Data
Trains the model to interpret SNP data and provide health insights.
"""
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
print("="*80)
print("Qwen2.5-72B LoRA Fine-tuning for Genomics Interpretation")
print("="*80)
# Load dataset from Hub
print("\n[1/4] Loading dataset...")
dataset = load_dataset("mattPearce/genellm-genomics-finetune", split="train")
print(f"✓ Loaded {len(dataset)} training examples")
# Create train/eval split for monitoring training progress
print("\n[2/4] Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)  # hold out 5% for eval
print(f"✓ Train: {len(dataset_split['train'])} examples")
print(f"✓ Eval: {len(dataset_split['test'])} examples")
# LoRA configuration optimized for 72B model
print("\n[3/4] Configuring LoRA...")
peft_config = LoraConfig(
    r=32,                # LoRA rank; higher rank adds capacity on large models
    lora_alpha=64,       # Scaling factor (typically 2x rank)
    target_modules=[     # Apply LoRA to all attention and MLP projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
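# Rough scale check (estimate, not measured): r=32 on all seven projection
# types adds on the order of a few hundred million trainable parameters, well
# under 1% of the 72B base. Verify once the trainer is built with
# trainer.model.print_trainable_parameters().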
print(f"✓ LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
print(f"✓ Target modules: {len(peft_config.target_modules)} layer types")
# Training configuration
print("\n[4/4] Setting up training...")
training_args = SFTConfig(
    # Model output
    output_dir="qwen2.5-72b-genomics-lora",
    # Hub configuration - CRITICAL for saving results
    push_to_hub=True,
    hub_model_id="mattPearce/qwen2.5-72b-genomics-lora",
    hub_strategy="every_save",        # Push checkpoints to the Hub
    hub_private_repo=False,
    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=1,    # Small batch for the 72B model
    gradient_accumulation_steps=16,   # Effective batch size = 16 per device
    learning_rate=2e-4,               # Standard for LoRA
    lr_scheduler_type="cosine",
    warmup_steps=50,
    # Optimization for memory efficiency
    gradient_checkpointing=True,
    bf16=True,                        # Use bfloat16 for better stability
    # Evaluation strategy
    eval_strategy="steps",
    eval_steps=25,                    # Evaluate every 25 steps
    # Checkpointing
    save_strategy="steps",
    save_steps=50,                    # Save every 50 steps
    save_total_limit=3,               # Keep only the 3 most recent checkpoints
    # Logging and monitoring
    logging_steps=5,
    report_to="trackio",
    run_name="qwen2.5-72b-genomics-v1",
    # Misc
    seed=42,
    remove_unused_columns=True,
)
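# Hardware note (estimate, not from the source): 72B parameters in bf16 are
# ~144 GB of weights before optimizer state and activations, so this config
# assumes a multi-GPU node (e.g. 8x 80 GB A100/H100) launched via
# `accelerate launch`, typically with DeepSpeed ZeRO-3 or FSDP to shard the model.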
print(f"✓ Training config:")
print(f" Epochs: {training_args.num_train_epochs}")
print(f" Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f" Learning rate: {training_args.learning_rate}")
print(f" Hub model ID: {training_args.hub_model_id}")
# Initialize trainer
print("\n[Starting Training]")
print("Model: Qwen/Qwen2.5-72B")
print("Method: LoRA fine-tuning")
print("Trackio monitoring: https://huggingface.co/spaces/mattPearce/trackio")
print("="*80)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-72B",
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
    peft_config=peft_config,
    args=training_args,
)
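# Passing a model id string makes SFTTrainer load the base model itself via
# AutoModelForCausalLM.from_pretrained() and wrap it with the LoRA adapter
# from peft_config; set model_init_kwargs in SFTConfig (for example
# {"torch_dtype": "bfloat16"}) to control how the base weights are loaded.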
# Train the model
trainer.train()
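# If the run is interrupted, it can be resumed from the newest local
# checkpoint with trainer.train(resume_from_checkpoint=True).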
# Save final model to Hub
print("\n[Finalizing]")
print("Pushing final model to Hub...")
trainer.push_to_hub()
print("\n" + "="*80)
print("✓ Training completed successfully!")
print(f"✓ Model saved to: https://huggingface.co/mattPearce/qwen2.5-72b-genomics-lora")
print("="*80)