mattPearce committed
Commit ee11a16 · verified · 1 Parent(s): 6aeaa4b

Upload train_genomics_lora.py with huggingface_hub

Files changed (1):
  1. train_genomics_lora.py +135 -0
train_genomics_lora.py ADDED
@@ -0,0 +1,135 @@
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "trackio",
#     "torch",
#     "transformers>=4.44.0",
#     "datasets",
#     "accelerate",
# ]
# ///
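
# Note: the "/// script" block above is PEP 723 inline metadata; a PEP 723-aware
# runner such as uv can resolve these dependencies on the fly, e.g.:
#   uv run train_genomics_lora.py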

"""
LoRA Fine-tuning for Qwen2.5-72B on Consumer Genomics Data
Trains the model to interpret SNP data and provide health insights.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio  # experiment-tracking backend used via report_to="trackio" below
import torch

print("=" * 80)
print("Qwen2.5-72B LoRA Fine-tuning for Genomics Interpretation")
print("=" * 80)

# Load dataset from Hub
print("\n[1/4] Loading dataset...")
dataset = load_dataset("mattPearce/genellm-genomics-finetune", split="train")
print(f"✓ Loaded {len(dataset)} training examples")

# Create train/eval split for monitoring training progress
print("\n[2/4] Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.05, seed=42)  # 5% for eval (25 examples)
print(f"✓ Train: {len(dataset_split['train'])} examples")
print(f"✓ Eval: {len(dataset_split['test'])} examples")
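
# Assumption: the dataset is in a format SFTTrainer accepts out of the box --
# either a plain "text" column or a conversational "messages" column, which
# TRL applies the tokenizer's chat template to automatically.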

# LoRA configuration optimized for 72B model
print("\n[3/4] Configuring LoRA...")
peft_config = LoraConfig(
    r=32,             # rank - higher for better quality on large models
    lora_alpha=64,    # scaling factor (typically 2x rank)
    target_modules=[  # apply LoRA to all attention and MLP projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
print(f"✓ LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
print(f"✓ Target modules: {len(peft_config.target_modules)} layer types")
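
# The effective LoRA scale applied to each adapter is lora_alpha / r = 2.0.
# To sanity-check the adapter's trainable-parameter count before committing to
# a full run, a minimal (and memory-hungry) sketch would be:
#
#   from peft import get_peft_model
#   from transformers import AutoModelForCausalLM
#   base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-72B")
#   get_peft_model(base, peft_config).print_trainable_parameters()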

# Training configuration
print("\n[4/4] Setting up training...")
training_args = SFTConfig(
    # Model output
    output_dir="qwen2.5-72b-genomics-lora",

    # Hub configuration - CRITICAL for saving results
    push_to_hub=True,
    hub_model_id="mattPearce/qwen2.5-72b-genomics-lora",
    hub_strategy="every_save",       # push checkpoints to Hub
    hub_private_repo=False,

    # Training hyperparameters
    num_train_epochs=3,
    per_device_train_batch_size=1,   # small batch for 72B model
    gradient_accumulation_steps=16,  # effective batch size = 16
    learning_rate=2e-4,              # standard for LoRA
    lr_scheduler_type="cosine",
    warmup_steps=50,

    # Optimization for memory efficiency
    gradient_checkpointing=True,
    bf16=True,                       # use bfloat16 for better stability

    # Evaluation strategy
    eval_strategy="steps",
    eval_steps=25,                   # evaluate every 25 steps

    # Checkpointing
    save_strategy="steps",
    save_steps=50,                   # save every 50 steps
    save_total_limit=3,              # keep only 3 most recent checkpoints

    # Logging and monitoring
    logging_steps=5,
    report_to="trackio",
    run_name="qwen2.5-72b-genomics-v1",

    # Misc
    seed=42,
    remove_unused_columns=True,
)
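
# Rough step math (assuming ~475 train examples after the split, as the eval
# comment above suggests): ceil(475 / 16) ≈ 30 optimizer steps per epoch, so
# ~90 steps over 3 epochs -- eval_steps=25 and save_steps=50 therefore fire
# only a handful of times per run.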

print("✓ Training config:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Hub model ID: {training_args.hub_model_id}")

# Initialize trainer
print("\n[Starting Training]")
print("Model: Qwen/Qwen2.5-72B")
print("Method: LoRA fine-tuning")
print("Trackio monitoring: https://huggingface.co/spaces/mattPearce/trackio")
print("=" * 80)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-72B",  # TRL instantiates the model from the Hub id
    train_dataset=dataset_split["train"],
    eval_dataset=dataset_split["test"],
    peft_config=peft_config,
    args=training_args,
)
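
# Hardware assumption: the base model is loaded in bf16 with no quantization,
# so the weights alone are roughly 72B x 2 bytes ≈ 144 GB -- this script
# presupposes a multi-GPU node (e.g. several 80 GB accelerators).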

# Train the model
trainer.train()
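
# If the run is interrupted, training can typically be resumed from the most
# recent local checkpoint with trainer.train(resume_from_checkpoint=True).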

# Save final model to Hub
print("\n[Finalizing]")
print("Pushing final model to Hub...")
trainer.push_to_hub()

print("\n" + "=" * 80)
print("✓ Training completed successfully!")
print("✓ Model saved to: https://huggingface.co/mattPearce/qwen2.5-72b-genomics-lora")
print("=" * 80)