#!/usr/bin/env python3
"""
Qwen2.5-7B + glaive-function-calling-v2 QLoRA学習スクリプト
マルチGPU対応版 (4xL40S等)
実行方法:
accelerate launch --num_processes 4 train_multi_gpu.py
"""
import os
import time
from datetime import datetime
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
from transformers.trainer_callback import TrainerCallback
# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora"
DATASET_NAME = "glaiveai/glaive-function-calling-v2"
CHECKPOINT_DIR = "./checkpoints"
FINAL_OUTPUT_DIR = "./output/final"
# ============================================================
# QLoRA quantization settings
# ============================================================
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
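# NF4 4-bit weights with double quantization (the quantization constants are
# themselves quantized) cut weight memory to roughly a quarter of bf16,
# while matmuls still run in bfloat16 via bnb_4bit_compute_dtype.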
# ============================================================
# LoRA settings
# ============================================================
lora_config = LoraConfig(
r=64,
lora_alpha=16,
lora_dropout=0.05,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
bias="none",
task_type="CAUSAL_LM",
)
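# r=64 with lora_alpha=16 gives an effective scaling of alpha/r = 0.25.
# Targeting every attention and MLP projection follows the QLoRA recipe of
# adapting all linear layers, trading extra adapter parameters for quality.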
# ============================================================
# Custom callback
# ============================================================
class VerboseLoggingCallback(TrainerCallback):
def __init__(self):
self.start_time = None
def on_train_begin(self, args, state, control, **kwargs):
self.start_time = time.time()
if state.is_world_process_zero:
print("\n" + "=" * 70)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
print(f" Total steps: {state.max_steps}")
print(f" Num GPUs: {args.world_size}")
print(f" Per device batch: {args.per_device_train_batch_size}")
print(f" Gradient accum: {args.gradient_accumulation_steps}")
print(f" Effective batch: {args.per_device_train_batch_size * args.gradient_accumulation_steps * args.world_size}")
print("=" * 70 + "\n")
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None or not state.is_world_process_zero:
return
current_time = time.time()
elapsed = current_time - self.start_time
elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
progress = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0
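        # Simple linear ETA. When resuming from a checkpoint this is optimistic
        # at first: elapsed restarts at zero while state.global_step does not.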
if state.global_step > 0:
time_per_step = elapsed / state.global_step
remaining_steps = state.max_steps - state.global_step
eta_seconds = time_per_step * remaining_steps
eta_str = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))
else:
eta_str = "calculating..."
loss = logs.get("loss", "N/A")
lr = logs.get("learning_rate", "N/A")
print(f"[{datetime.now().strftime('%H:%M:%S')}] "
f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%) | "
f"Loss: {loss:.4f if isinstance(loss, float) else loss} | "
f"LR: {lr:.2e if isinstance(lr, float) else lr} | "
f"Elapsed: {elapsed_str} | ETA: {eta_str}")
def on_save(self, args, state, control, **kwargs):
if state.is_world_process_zero:
print(f"\n[{datetime.now().strftime('%H:%M:%S')}] "
f"💾 Checkpoint saved at step {state.global_step}\n")
def on_train_end(self, args, state, control, **kwargs):
if state.is_world_process_zero:
total_time = time.time() - self.start_time
total_str = time.strftime("%H:%M:%S", time.gmtime(total_time))
print("\n" + "=" * 70)
print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
print(f" Total time: {total_str}")
print("=" * 70 + "\n")
# ============================================================
# Dataset conversion
# ============================================================
def convert_glaive_to_chatml(example: dict) -> dict:
    """Convert one glaive-function-calling-v2 record into a single ChatML string."""
    parts = []
if example.get("system"):
parts.append(f"<|im_start|>system\n{example['system']}<|im_end|>")
chat = example.get("chat", "")
if chat:
current_role = None
current_content = []
for line in chat.split("\n"):
line = line.strip()
if line.startswith("USER:"):
if current_role and current_content:
content = "\n".join(current_content).strip()
if content:
parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
current_role = "user"
current_content = [line[5:].strip()]
elif line.startswith("ASSISTANT:"):
if current_role and current_content:
content = "\n".join(current_content).strip()
if content:
parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
current_role = "assistant"
current_content = [line[10:].strip()]
            elif current_role:
                # Continuation lines (including "FUNCTION RESPONSE:" blocks from
                # the source data) are folded into the current speaker's turn.
                current_content.append(line)
if current_role and current_content:
content = "\n".join(current_content).strip()
if content:
parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
return {"text": "\n".join(parts)}
def load_and_prepare_dataset():
print(f"\nLoading dataset: {DATASET_NAME}")
dataset = load_dataset(DATASET_NAME, split="train")
print(f"Original size: {len(dataset)} examples")
dataset = dataset.map(
convert_glaive_to_chatml,
remove_columns=dataset.column_names,
num_proc=4,
desc="Converting"
)
dataset = dataset.filter(lambda x: len(x["text"]) > 50)
print(f"After filtering: {len(dataset)} examples")
dataset = dataset.shuffle(seed=42)
split = dataset.train_test_split(test_size=0.02, seed=42)
print(f"Train: {len(split['train'])}, Test: {len(split['test'])}")
return split
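# Note: with packing=True (set in the SFTConfig below), examples are regrouped
# into fixed 2048-token blocks before training, so optimizer steps do not map
# 1:1 onto the raw example counts printed above.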
# ============================================================
# Training parameters (multi-GPU optimized)
# ============================================================
training_args = SFTConfig(
    output_dir=CHECKPOINT_DIR,
    num_train_epochs=2,
    # Multi-GPU: each L40S has 48GB of VRAM, so the per-device batch can be large
    per_device_train_batch_size=8,  # 8 per GPU (L40S 48GB)
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch on 4 GPUs: 8*2*4 = 64
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    logging_steps=10,
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,
    report_to="none",
    group_by_length=True,
    gradient_checkpointing=True,
    # Non-reentrant checkpointing avoids "parameter marked ready twice" errors under DDP
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Multi-GPU settings
    ddp_find_unused_parameters=False,
    dataloader_num_workers=4,
    save_safetensors=True,
    # SFT-specific options (recent TRL expects these on SFTConfig, not SFTTrainer)
    max_seq_length=2048,
    packing=True,
    dataset_text_field="text",
)
# ============================================================
# Main
# ============================================================
def main():
local_rank = int(os.environ.get("LOCAL_RANK", 0))
is_main = local_rank == 0
if is_main:
print("\n" + "=" * 70)
print(" Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training")
print(" Multi-GPU Version")
print("=" * 70)
print(f"Start: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"GPUs available: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
print("=" * 70 + "\n")
    # Dataset
dataset = load_and_prepare_dataset()
    # Tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
    # Model
if is_main:
print(f"\nLoading model: {BASE_MODEL}")
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL,
quantization_config=bnb_config,
device_map={"": local_rank}, # 各GPUに配置
attn_implementation="sdpa",
trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
if is_main:
model.print_trainable_parameters()
    # Trainer: the model is already wrapped by get_peft_model above, so
    # peft_config is not passed again (doing both would double-wrap it);
    # SFT-specific options live on the SFTConfig.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        processing_class=tokenizer,
        callbacks=[VerboseLoggingCallback()],
    )
    # Resume from the latest checkpoint, if any
resume_from = None
if os.path.exists(CHECKPOINT_DIR):
checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")]
if checkpoints:
latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
resume_from = os.path.join(CHECKPOINT_DIR, latest)
if is_main:
print(f"\n📂 Resuming from: {resume_from}")
    # Train
trainer.train(resume_from_checkpoint=resume_from)
    # Save (main process only)
if is_main:
print(f"\nSaving to {FINAL_OUTPUT_DIR}...")
trainer.save_model(FINAL_OUTPUT_DIR)
tokenizer.save_pretrained(FINAL_OUTPUT_DIR)
print(f"\nUploading to: {OUTPUT_MODEL_ID}")
try:
trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True)
tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True)
print(f"✅ Uploaded: https://huggingface.co/{OUTPUT_MODEL_ID}")
except Exception as e:
print(f"⚠️ Upload failed: {e}")
print("\n🎉 Training complete!")
if __name__ == "__main__":
main()
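# To try the trained adapter afterwards (illustrative sketch; assumes the peft
# package and the paths configured above):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto")
#   model = PeftModel.from_pretrained(base, FINAL_OUTPUT_DIR)
#   model.eval()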