Hajime MATSUMOTO commited on
Commit
1cc8a56
·
1 Parent(s): 9706c88

Add multi-GPU training script for 4xL40S

Browse files
Files changed (1) hide show
  1. train_multi_gpu.py +321 -0
train_multi_gpu.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Qwen2.5-7B + glaive-function-calling-v2 QLoRA学習スクリプト
4
+ マルチGPU対応版 (4xA10G等)
5
+
6
+ 実行方法:
7
+ accelerate launch --num_processes 4 train_multi_gpu.py
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import time
13
+ from datetime import datetime
14
+
15
+ import torch
16
+ from datasets import load_dataset
17
+ from transformers import (
18
+ AutoModelForCausalLM,
19
+ AutoTokenizer,
20
+ BitsAndBytesConfig,
21
+ TrainingArguments,
22
+ )
23
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
24
+ from trl import SFTTrainer
25
+ from transformers.trainer_callback import TrainerCallback
26
+
27
+ # ============================================================
28
+ # 設定
29
+ # ============================================================
30
+ BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
31
+ OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora"
32
+ DATASET_NAME = "glaiveai/glaive-function-calling-v2"
33
+
34
+ CHECKPOINT_DIR = "./checkpoints"
35
+ FINAL_OUTPUT_DIR = "./output/final"
36
+
37
+ # ============================================================
38
+ # QLoRA量子化設定
39
+ # ============================================================
40
+ bnb_config = BitsAndBytesConfig(
41
+ load_in_4bit=True,
42
+ bnb_4bit_compute_dtype=torch.bfloat16,
43
+ bnb_4bit_quant_type="nf4",
44
+ bnb_4bit_use_double_quant=True,
45
+ )
46
+
47
+ # ============================================================
48
+ # LoRA設定
49
+ # ============================================================
50
+ lora_config = LoraConfig(
51
+ r=64,
52
+ lora_alpha=16,
53
+ lora_dropout=0.05,
54
+ target_modules=[
55
+ "q_proj", "k_proj", "v_proj", "o_proj",
56
+ "gate_proj", "up_proj", "down_proj"
57
+ ],
58
+ bias="none",
59
+ task_type="CAUSAL_LM",
60
+ )
61
+
62
+
63
+ # ============================================================
64
+ # カスタムコールバック
65
+ # ============================================================
66
+ class VerboseLoggingCallback(TrainerCallback):
67
+ def __init__(self):
68
+ self.start_time = None
69
+
70
+ def on_train_begin(self, args, state, control, **kwargs):
71
+ self.start_time = time.time()
72
+ if state.is_world_process_zero:
73
+ print("\n" + "=" * 70)
74
+ print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
75
+ print(f" Total steps: {state.max_steps}")
76
+ print(f" Num GPUs: {args.world_size}")
77
+ print(f" Per device batch: {args.per_device_train_batch_size}")
78
+ print(f" Gradient accum: {args.gradient_accumulation_steps}")
79
+ print(f" Effective batch: {args.per_device_train_batch_size * args.gradient_accumulation_steps * args.world_size}")
80
+ print("=" * 70 + "\n")
81
+
82
+ def on_log(self, args, state, control, logs=None, **kwargs):
83
+ if logs is None or not state.is_world_process_zero:
84
+ return
85
+
86
+ current_time = time.time()
87
+ elapsed = current_time - self.start_time
88
+ elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
89
+
90
+ progress = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0
91
+
92
+ if state.global_step > 0:
93
+ time_per_step = elapsed / state.global_step
94
+ remaining_steps = state.max_steps - state.global_step
95
+ eta_seconds = time_per_step * remaining_steps
96
+ eta_str = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))
97
+ else:
98
+ eta_str = "calculating..."
99
+
100
+ loss = logs.get("loss", "N/A")
101
+ lr = logs.get("learning_rate", "N/A")
102
+
103
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] "
104
+ f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%) | "
105
+ f"Loss: {loss:.4f if isinstance(loss, float) else loss} | "
106
+ f"LR: {lr:.2e if isinstance(lr, float) else lr} | "
107
+ f"Elapsed: {elapsed_str} | ETA: {eta_str}")
108
+
109
+ def on_save(self, args, state, control, **kwargs):
110
+ if state.is_world_process_zero:
111
+ print(f"\n[{datetime.now().strftime('%H:%M:%S')}] "
112
+ f"💾 Checkpoint saved at step {state.global_step}\n")
113
+
114
+ def on_train_end(self, args, state, control, **kwargs):
115
+ if state.is_world_process_zero:
116
+ total_time = time.time() - self.start_time
117
+ total_str = time.strftime("%H:%M:%S", time.gmtime(total_time))
118
+ print("\n" + "=" * 70)
119
+ print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
120
+ print(f" Total time: {total_str}")
121
+ print("=" * 70 + "\n")
122
+
123
+
124
+ # ============================================================
125
+ # データセット変換
126
+ # ============================================================
127
+ def convert_glaive_to_chatml(example: dict) -> dict:
128
+ parts = []
129
+
130
+ if example.get("system"):
131
+ parts.append(f"<|im_start|>system\n{example['system']}<|im_end|>")
132
+
133
+ chat = example.get("chat", "")
134
+ if chat:
135
+ current_role = None
136
+ current_content = []
137
+
138
+ for line in chat.split("\n"):
139
+ line = line.strip()
140
+ if line.startswith("USER:"):
141
+ if current_role and current_content:
142
+ content = "\n".join(current_content).strip()
143
+ if content:
144
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
145
+ current_role = "user"
146
+ current_content = [line[5:].strip()]
147
+ elif line.startswith("ASSISTANT:"):
148
+ if current_role and current_content:
149
+ content = "\n".join(current_content).strip()
150
+ if content:
151
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
152
+ current_role = "assistant"
153
+ current_content = [line[10:].strip()]
154
+ elif current_role:
155
+ current_content.append(line)
156
+
157
+ if current_role and current_content:
158
+ content = "\n".join(current_content).strip()
159
+ if content:
160
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
161
+
162
+ return {"text": "\n".join(parts)}
163
+
164
+
165
+ def load_and_prepare_dataset():
166
+ print(f"\nLoading dataset: {DATASET_NAME}")
167
+
168
+ dataset = load_dataset(DATASET_NAME, split="train")
169
+ print(f"Original size: {len(dataset)} examples")
170
+
171
+ dataset = dataset.map(
172
+ convert_glaive_to_chatml,
173
+ remove_columns=dataset.column_names,
174
+ num_proc=4,
175
+ desc="Converting"
176
+ )
177
+
178
+ dataset = dataset.filter(lambda x: len(x["text"]) > 50)
179
+ print(f"After filtering: {len(dataset)} examples")
180
+
181
+ dataset = dataset.shuffle(seed=42)
182
+ split = dataset.train_test_split(test_size=0.02, seed=42)
183
+
184
+ print(f"Train: {len(split['train'])}, Test: {len(split['test'])}")
185
+ return split
186
+
187
+
188
+ # ============================================================
189
+ # 学習パラメータ(マルチGPU最適化)
190
+ # ============================================================
191
+ num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
192
+
193
+ training_args = TrainingArguments(
194
+ output_dir=CHECKPOINT_DIR,
195
+
196
+ num_train_epochs=2,
197
+
198
+ # マルチGPU: バッチサイズを上げる
199
+ per_device_train_batch_size=4, # 1GPUあたり4
200
+ per_device_eval_batch_size=4,
201
+ gradient_accumulation_steps=4, # 有効バッチ: 4*4*num_gpus
202
+
203
+ learning_rate=1e-4,
204
+ weight_decay=0.01,
205
+ warmup_ratio=0.03,
206
+ lr_scheduler_type="cosine",
207
+
208
+ optim="paged_adamw_8bit",
209
+ fp16=False,
210
+ bf16=True,
211
+ max_grad_norm=0.3,
212
+
213
+ logging_steps=10,
214
+ save_steps=500,
215
+ save_total_limit=3,
216
+ eval_strategy="steps",
217
+ eval_steps=500,
218
+
219
+ report_to="none",
220
+ group_by_length=True,
221
+ gradient_checkpointing=True,
222
+
223
+ # マルチGPU設定
224
+ ddp_find_unused_parameters=False,
225
+ dataloader_num_workers=4,
226
+
227
+ save_safetensors=True,
228
+ )
229
+
230
+
231
+ # ============================================================
232
+ # メイン
233
+ # ============================================================
234
+ def main():
235
+ local_rank = int(os.environ.get("LOCAL_RANK", 0))
236
+ is_main = local_rank == 0
237
+
238
+ if is_main:
239
+ print("\n" + "=" * 70)
240
+ print(" Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training")
241
+ print(" Multi-GPU Version")
242
+ print("=" * 70)
243
+ print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
244
+ print(f"GPUs available: {torch.cuda.device_count()}")
245
+ for i in range(torch.cuda.device_count()):
246
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
247
+ print("=" * 70 + "\n")
248
+
249
+ # データセット
250
+ dataset = load_and_prepare_dataset()
251
+
252
+ # トークナイザー
253
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
254
+ tokenizer.padding_side = "right"
255
+ if tokenizer.pad_token is None:
256
+ tokenizer.pad_token = tokenizer.eos_token
257
+
258
+ # モデル
259
+ if is_main:
260
+ print(f"\nLoading model: {BASE_MODEL}")
261
+
262
+ model = AutoModelForCausalLM.from_pretrained(
263
+ BASE_MODEL,
264
+ quantization_config=bnb_config,
265
+ device_map={"": local_rank}, # 各GPUに配置
266
+ attn_implementation="sdpa",
267
+ trust_remote_code=True,
268
+ )
269
+
270
+ model = prepare_model_for_kbit_training(model)
271
+ model = get_peft_model(model, lora_config)
272
+
273
+ if is_main:
274
+ model.print_trainable_parameters()
275
+
276
+ # Trainer
277
+ trainer = SFTTrainer(
278
+ model=model,
279
+ train_dataset=dataset["train"],
280
+ eval_dataset=dataset["test"],
281
+ args=training_args,
282
+ peft_config=lora_config,
283
+ processing_class=tokenizer,
284
+ max_seq_length=2048,
285
+ packing=True,
286
+ dataset_text_field="text",
287
+ callbacks=[VerboseLoggingCallback()],
288
+ )
289
+
290
+ # チェックポイント再開
291
+ resume_from = None
292
+ if os.path.exists(CHECKPOINT_DIR):
293
+ checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")]
294
+ if checkpoints:
295
+ latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
296
+ resume_from = os.path.join(CHECKPOINT_DIR, latest)
297
+ if is_main:
298
+ print(f"\n📂 Resuming from: {resume_from}")
299
+
300
+ # 学習
301
+ trainer.train(resume_from_checkpoint=resume_from)
302
+
303
+ # 保存(メインプロセスのみ)
304
+ if is_main:
305
+ print(f"\nSaving to {FINAL_OUTPUT_DIR}...")
306
+ trainer.save_model(FINAL_OUTPUT_DIR)
307
+ tokenizer.save_pretrained(FINAL_OUTPUT_DIR)
308
+
309
+ print(f"\nUploading to: {OUTPUT_MODEL_ID}")
310
+ try:
311
+ trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True)
312
+ tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True)
313
+ print(f"✅ Uploaded: https://huggingface.co/{OUTPUT_MODEL_ID}")
314
+ except Exception as e:
315
+ print(f"⚠️ Upload failed: {e}")
316
+
317
+ print("\n🎉 Training complete!")
318
+
319
+
320
+ if __name__ == "__main__":
321
+ main()