OliverSlivka committed on
Commit
5973bed
·
verified ·
1 Parent(s): 3fdc9a0

Upload run_sft_simplified.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run_sft_simplified.py +174 -0
run_sft_simplified.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simplified SFT training script for Qwen2.5-0.5B-Instruct
3
+ Based on official HuggingFace TRL examples
4
+ Dataset loaded from GitHub to avoid Hub caching issues
5
+ """
import os
import shutil
import subprocess

import torch
from datasets import load_from_disk
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
# ===== 1. Clone Dataset from GitHub =====
# SECURITY FIX: a GitHub personal access token was previously hard-coded here,
# which leaks the credential to anyone who can read this file or its history.
# Read it from the environment instead (set GIT_TOKEN before running) and
# revoke the old token on GitHub.
GIT_TOKEN = os.environ["GIT_TOKEN"]
GIT_REPO_URL = f"https://{GIT_TOKEN}@github.com/oliversl1vka/itemsety-qwen-finetuning.git"
CLONE_PATH = "/tmp/itemsety-qwen-finetuning"
DATASET_PATH = f"{CLONE_PATH}/hf_dataset_enhanced"

print("πŸ“¦ Cloning dataset from private GitHub repo...")
# List-form argv (shell=False) so the token-bearing URL is never passed
# through a shell; check=True aborts the run on clone failure.
subprocess.run(['git', 'clone', GIT_REPO_URL, CLONE_PATH], check=True)
print("βœ… Clone complete")

# Security: remove .git so the token embedded in the remote URL does not
# linger on disk. shutil.rmtree is portable and avoids shelling out to `rm`.
shutil.rmtree(f"{CLONE_PATH}/.git")
print("πŸ” Removed .git directory")
# ===== 2. Load Dataset =====
# The dataset was saved with datasets.save_to_disk, so load_from_disk
# restores the DatasetDict (train/validation splits) without touching the Hub.
print(f"πŸ’Ύ Loading dataset from {DATASET_PATH}...")
dataset = load_from_disk(DATASET_PATH)
# Keep the split handles the trainer consumes later in the script.
train_dataset, eval_dataset = dataset["train"], dataset["validation"]

print(f"βœ… Dataset loaded: {len(train_dataset)} train, {len(eval_dataset)} eval examples")
print(f" Columns: {train_dataset.column_names}")
print(f" First example keys: {list(train_dataset[0].keys())}")
# ===== 3. Load Model with 4-bit Quantization =====
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
OUTPUT_DIR = "OliverSlivka/qwen-itemsety-qlora"

print(f"πŸ”₯ Loading {MODEL_NAME} with 4-bit quantization...")

# Tokenizer first; guarantee a pad token exists for batched training
# (Qwen tokenizers may ship without one).
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# QLoRA recipe: NF4 double-quantized weights with bf16 compute.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)

print("βœ… Model and tokenizer loaded with 4-bit quantization")
# ===== 4. LoRA Configuration =====
# Adapt every attention and MLP projection of the transformer blocks.
_LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,               # adapter rank
    lora_alpha=32,      # scaling (alpha / r = 2)
    lora_dropout=0.05,
    bias="none",
    target_modules=_LORA_TARGETS,
)

print(f"🎯 LoRA config: r={peft_config.r}, alpha={peft_config.lora_alpha}")
# ===== 5. Training Configuration =====
training_args = SFTConfig(
    # Checkpoint directory doubles as the Hub repo id to push to.
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    hub_model_id=OUTPUT_DIR,

    # Schedule: three full epochs (max_steps=-1 means no step cap);
    # effective batch = 4 per device x 4 accumulation steps.
    num_train_epochs=3,
    max_steps=-1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_steps=10,

    # Memory-friendly optimization for QLoRA.
    optim="paged_adamw_8bit",
    max_grad_norm=0.3,
    gradient_checkpointing=True,
    bf16=True,

    # Logging / metric reporting (trackio Space shares the repo id).
    logging_steps=5,
    report_to="trackio",
    trackio_space_id=OUTPUT_DIR,

    # Periodic evaluation and checkpointing.
    eval_strategy="steps",
    eval_steps=20,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=2,

    # Token budget per packed/truncated example.
    max_length=2048,
)

print("βœ… Training configuration set")
print(f" Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f" Epochs: {training_args.num_train_epochs}")
print(f" Learning rate: {training_args.learning_rate}")
# ===== 6. Initialize Trainer =====
print("🎯 Initializing SFTTrainer...")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    # FIX: hand over the tokenizer configured in section 3. Without this,
    # SFTTrainer loads its own tokenizer from the model name and the
    # pad-token fallback set earlier is silently ignored.
    processing_class=tokenizer,
)

print("βœ… Trainer initialized")

# Show GPU memory before training (start_memory/max_memory are reused for
# the post-training report; they only exist when CUDA is present).
if torch.cuda.is_available():
    gpu_stats = torch.cuda.get_device_properties(0)
    start_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    max_memory = round(gpu_stats.total_memory / 1024**3, 3)
    print(f"\nπŸ–₯️ GPU: {gpu_stats.name}")
    print(f" Max memory: {max_memory} GB")
    print(f" Reserved: {start_memory} GB")
# ===== 7. Train =====
print("\nπŸš€ Starting training...")
divider = "=" * 60
print(divider)

trainer_stats = trainer.train()

print(divider)
print("βœ… Training complete!")

# Report peak GPU memory relative to the pre-training baseline.
if torch.cuda.is_available():
    used_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    training_memory = round(used_memory - start_memory, 3)
    runtime_minutes = round(trainer_stats.metrics['train_runtime'] / 60, 2)
    peak_percent = round(used_memory / max_memory * 100, 1)
    print(f"\nπŸ“Š Training stats:")
    print(f" Runtime: {runtime_minutes} minutes")
    print(f" Peak memory: {used_memory} GB ({peak_percent}%)")
    print(f" Training memory: {training_memory} GB")
# ===== 8. Push to Hub =====
# Uploads the final adapter/checkpoint to the Hub repo named in OUTPUT_DIR.
print("\nπŸ’Ύ Pushing final model to Hub...")
trainer.push_to_hub()
for summary_line in (
    f"βœ… Model pushed to: https://huggingface.co/{OUTPUT_DIR}",
    f"πŸ“Š View training metrics at: https://huggingface.co/spaces/{OUTPUT_DIR}",
):
    print(summary_line)

print("\nπŸŽ‰ All done!")