vgtomahawk committed
Commit a7d8dc8 · verified · 1 Parent(s): a4f7595

Upload train_sft_qwen.py with huggingface_hub

Files changed (1)
  1. train_sft_qwen.py +105 -0
train_sft_qwen.py ADDED
@@ -0,0 +1,105 @@
+ #!/usr/bin/env python3
+ # /// script
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "transformers>=4.36.0",
+ #     "accelerate>=0.24.0",
+ #     "trackio",
+ # ]
+ # ///
+
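+ # The "# /// script" block above is PEP 723 inline metadata; a PEP 723-aware
+ # runner (e.g. `uv run train_sft_qwen.py`) installs these dependencies on launch.
+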
+ """
+ SFT Training Script for Qwen/Qwen2.5-0.5B
+
+ This script fine-tunes Qwen/Qwen2.5-0.5B using Supervised Fine-Tuning (SFT)
+ with LoRA for efficient training on the Capybara dataset.
+
+ Features:
+ - Trackio integration for real-time monitoring
+ - LoRA/PEFT for memory-efficient training
+ - Automatic Hub saving with checkpoints
+ - Train/eval split for progress monitoring
+ """
+
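+ # Importing trackio up front makes a missing install fail at import time; the
+ # actual reporting is wired up via report_to="trackio" in SFTConfig below.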
+ import trackio
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+
+ # Load dataset
+ print("📦 Loading dataset...")
+ dataset = load_dataset("trl-lib/Capybara", split="train")
+ print(f"✅ Dataset loaded: {len(dataset)} examples")
+
+ # Create train/eval split
+ print("🔀 Creating train/eval split...")
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+ train_dataset = dataset_split["train"]
+ eval_dataset = dataset_split["test"]
+ print(f"   Train: {len(train_dataset)} examples")
+ print(f"   Eval:  {len(eval_dataset)} examples")
+
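+ # Effective batch size: per_device_train_batch_size (4) x gradient_accumulation_steps (4)
+ # = 16 sequences per device per optimizer step.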
+ # Training configuration
+ config = SFTConfig(
+     # CRITICAL: Hub settings
+     output_dir="qwen-capybara-sft",
+     push_to_hub=True,
+     hub_model_id="vgtomahawk/qwen-capybara-sft",
+     hub_strategy="every_save",  # Push checkpoints
+
+     # Training parameters
+     num_train_epochs=3,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-5,
+
+     # Logging & checkpointing
+     logging_steps=10,
+     save_strategy="steps",
+     save_steps=100,
+     save_total_limit=2,
+
+     # Evaluation
+     eval_strategy="steps",
+     eval_steps=100,
+
+     # Optimization
+     warmup_ratio=0.1,
+     lr_scheduler_type="cosine",
+
+     # Monitoring with Trackio
+     report_to="trackio",
+     project="qwen-sft-training",
+     run_name="qwen-0.5b-capybara-baseline",
+ )
+
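+ # LoRA scaling factor is lora_alpha / r = 32 / 16 = 2.0; adapters are attached
+ # only to the attention q_proj and v_proj matrices, so the base model stays frozen.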
+ # LoRA configuration for efficient training
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "v_proj"],
+ )
+
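+ # Passing the model as a string lets SFTTrainer load "Qwen/Qwen2.5-0.5B" (and
+ # its tokenizer) itself, rather than requiring a preloaded AutoModelForCausalLM.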
+ # Initialize trainer
+ print("🎯 Initializing trainer...")
+ trainer = SFTTrainer(
+     model="Qwen/Qwen2.5-0.5B",
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     args=config,
+     peft_config=peft_config,
+ )
+
+ print("🚀 Starting training...")
+ trainer.train()
+
+ print("💾 Pushing final model to Hub...")
+ trainer.push_to_hub()
+
+ print("✅ Training complete!")
+ print("📦 Model: https://huggingface.co/vgtomahawk/qwen-capybara-sft")
+ print("📊 Metrics: https://huggingface.co/spaces/vgtomahawk/trackio")
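Once training finishes, one quick way to sanity-check the pushed adapter is to load it back on top of the base model. This is a minimal sketch, not part of the committed script: it assumes the run above completed and pushed LoRA adapter weights to vgtomahawk/qwen-capybara-sft, and the prompt string is purely illustrative.

# sanity_check.py — minimal sketch, assuming the training run above completed
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Loads the LoRA adapter and, via its adapter config, the frozen Qwen base weights
model = AutoPeftModelForCausalLM.from_pretrained("vgtomahawk/qwen-capybara-sft")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

prompt = "What is supervised fine-tuning?"  # illustrative prompt
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

AutoPeftModelForCausalLM reads base_model_name_or_path from the pushed adapter config, so the base model is fetched automatically and only the small adapter repo above needs to be referenced.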