glaive-7b-training / train.py
Hajime MATSUMOTO
L40S optimization: batch 8, disable gradient checkpointing, parallel dataloader
113833d
#!/usr/bin/env python3
"""
Qwen2.5-7B-Instruct + glaive-function-calling-v2 QLoRA学習スクリプト
目的: Function Calling能力の強化
データセット: glaiveai/glaive-function-calling-v2 (113k samples)
"""
import os
import sys
import time
from datetime import datetime
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from transformers.trainer_callback import TrainerCallback
# ============================================================
# Settings
# ============================================================
# Hugging Face identifiers of the base model and the training data.
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
DATASET_NAME = "glaiveai/glaive-function-calling-v2"
# Hub repository the trained adapter is pushed to (cloud-training variant).
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora-cloud"

# Local output locations: intermediate Trainer checkpoints and final adapter.
CHECKPOINT_DIR = "./checkpoints"
FINAL_OUTPUT_DIR = "./output/final"
# ============================================================
# QLoRA quantization settings
# ============================================================
# Base weights are loaded in 4-bit NF4 with double quantization enabled;
# computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# ============================================================
# LoRA settings
# ============================================================
# Adapters are attached to the attention projections (q/k/v/o) and the MLP
# projections (gate/up/down).
_LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",  # attention
    "gate_proj", "up_proj", "down_proj",  # MLP
]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,  # PEFT scales the update by lora_alpha / r = 0.25
    lora_dropout=0.05,
    bias="none",
    target_modules=_LORA_TARGET_MODULES,
)
# ============================================================
# Custom callback: periodic progress logging
# ============================================================
class VerboseLoggingCallback(TrainerCallback):
    """Print human-readable training progress to stdout.

    Prints a banner when training starts and ends. Each ``on_log`` event
    prints one line with the step, loss, learning rate, elapsed time and ETA.
    GPU memory is printed on steps divisible by 100, and a note is printed
    whenever a checkpoint is saved.
    """

    def __init__(self):
        self.start_time = None
        self.last_log_time = None
        # global_step when this process started training; non-zero when the
        # Trainer resumed from a checkpoint.
        self.start_step = 0

    @staticmethod
    def _format_duration(seconds):
        """Format a duration as ``HH:MM:SS``.

        Hours are not wrapped at 24, unlike ``strftime`` on ``gmtime``.
        """
        total = max(int(seconds), 0)
        hours, rest = divmod(total, 3600)
        minutes, secs = divmod(rest, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time/step and print the run configuration."""
        self.start_time = time.time()
        self.last_log_time = self.start_time
        self.start_step = state.global_step
        print("\n" + "=" * 70)
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
        print(f" Total steps: {state.max_steps}")
        print(f" Epochs: {args.num_train_epochs}")
        print(f" Batch size: {args.per_device_train_batch_size} x {args.gradient_accumulation_steps}")
        if self.start_step > 0:
            print(f" Resumed at step: {self.start_step}")
        print("=" * 70 + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print one progress line for a Trainer logging event."""
        if logs is None:
            return

        now = time.time()
        if self.start_time is None:
            # Callback attached after on_train_begin fired: measure from here.
            self.start_time = now
            self.start_step = state.global_step
        elapsed = now - self.start_time
        self.last_log_time = now

        # Progress
        progress = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0

        # ETA: only the steps executed by this process are timed. Dividing by
        # the absolute global_step under-estimated the ETA after a resume,
        # because elapsed restarts at zero while global_step does not.
        steps_this_run = state.global_step - self.start_step
        if steps_this_run > 0:
            remaining_steps = max(state.max_steps - state.global_step, 0)
            eta_str = self._format_duration(elapsed / steps_this_run * remaining_steps)
        else:
            eta_str = "calculating..."

        # Training logs carry "loss"; evaluation logs carry "eval_loss".
        if "loss" in logs:
            loss_label, loss = "Loss", logs["loss"]
        elif "eval_loss" in logs:
            loss_label, loss = "Eval loss", logs["eval_loss"]
        else:
            loss_label, loss = "Loss", "N/A"
        lr = logs.get("learning_rate", "N/A")
        loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else str(loss)
        lr_str = f"{lr:.2e}" if isinstance(lr, (int, float)) else str(lr)

        print(f"[{datetime.now().strftime('%H:%M:%S')}] "
              f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%) | "
              f"{loss_label}: {loss_str} | "
              f"LR: {lr_str} | "
              f"Elapsed: {self._format_duration(elapsed)} | ETA: {eta_str}")

        # GPU memory usage on logging events whose step is a multiple of 100.
        if state.global_step % 100 == 0 and torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            print(f" GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

    def on_save(self, args, state, control, **kwargs):
        """Announce that the Trainer wrote a checkpoint."""
        print(f"\n[{datetime.now().strftime('%H:%M:%S')}] "
              f"💾 Checkpoint saved at step {state.global_step}\n")

    def on_train_end(self, args, state, control, **kwargs):
        """Print total wall-clock time of this run and the final step."""
        total_time = time.time() - self.start_time if self.start_time is not None else 0.0
        print("\n" + "=" * 70)
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
        print(f" Total time: {self._format_duration(total_time)}")
        print(f" Final step: {state.global_step}")
        print("=" * 70 + "\n")
# ============================================================
# Dataset conversion
# ============================================================
def convert_glaive_to_chatml(example: dict) -> dict:
    """Convert one glaive-function-calling-v2 record into a ChatML string.

    Source fields:
      - ``system``: system prompt holding the function definitions. A leading
        ``"SYSTEM:"`` label is removed if present.
      - ``chat``: transcript whose turns begin on lines starting with
        ``"USER:"``, ``"ASSISTANT:"`` or ``"FUNCTION RESPONSE:"``.

    Each turn is rendered as ``<|im_start|>{role}``, a newline, the content
    and ``<|im_end|>``. Turns are joined with newlines.

    ``FUNCTION RESPONSE:`` turns get their own ``tool`` turn. Previously they
    were appended to the preceding assistant turn, which trained the model to
    generate tool output itself. Literal ``<|endoftext|>`` markers are removed
    from turn contents. Continuation lines keep their leading indentation.
    Turns with empty content are dropped. Text before the first role marker
    is ignored.

    Args:
        example: A dataset row; missing, ``None`` or empty fields are tolerated.

    Returns:
        ``{"text": chatml}``, where ``chatml`` is ``""`` if nothing is usable.
    """
    # NOTE(review): "tool" is a plain ChatML role name chosen here; Qwen2.5's
    # own chat template renders tool results differently - adjust if needed.
    role_markers = (
        ("USER:", "user"),
        ("ASSISTANT:", "assistant"),
        ("FUNCTION RESPONSE:", "tool"),
    )
    parts = []

    def emit(role, lines):
        """Append one ChatML turn unless its cleaned content is empty."""
        content = "\n".join(lines).replace("<|endoftext|>", "").strip()
        if content:
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")

    # System prompt
    system = (example.get("system") or "").strip()
    if system.startswith("SYSTEM:"):
        system = system[len("SYSTEM:"):]
    emit("system", [system])

    # Conversation: a line starting with a role marker opens a new turn;
    # any other line continues the current turn.
    current_role = None
    current_lines = []
    for raw_line in (example.get("chat") or "").split("\n"):
        stripped = raw_line.strip()
        for marker, role in role_markers:
            if stripped.startswith(marker):
                if current_role is not None:
                    emit(current_role, current_lines)
                current_role = role
                current_lines = [stripped[len(marker):].strip()]
                break
        else:
            if current_role is not None:
                # Keep indentation (code blocks); only trailing space is dropped.
                current_lines.append(raw_line.rstrip())
    if current_role is not None:
        emit(current_role, current_lines)

    return {"text": "\n".join(parts)}
def load_and_prepare_dataset():
    """Load the glaive dataset, convert it to ChatML and split it.

    Every example is converted with ``convert_glaive_to_chatml`` using 4
    worker processes. Examples without usable content are dropped. The rest
    are shuffled and split 98% / 2% with a fixed seed (42). Because the seed
    is fixed, the split and order are identical across runs, so a run
    resumed from a checkpoint sees the same data order.

    Returns:
        The ``DatasetDict`` from ``train_test_split`` with ``"train"`` and
        ``"test"`` splits, each holding a single ``"text"`` column.

    Raises:
        ValueError: If no example survives filtering.
    """
    print(f"\n{'=' * 60}")
    print(f"Loading dataset: {DATASET_NAME}")
    print(f"{'=' * 60}")
    # Load
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"Original size: {len(dataset)} examples")
    # Convert
    print("Converting to ChatML format...")
    dataset = dataset.map(
        convert_glaive_to_chatml,
        remove_columns=dataset.column_names,
        num_proc=4,
        desc="Converting"
    )
    # Drop near-empty samples and samples without an assistant turn. The
    # length check alone kept system-prompt-only samples (function
    # definitions easily exceed 50 chars), which have nothing to learn from.
    dataset = dataset.filter(
        lambda x: len(x["text"]) > 50 and "<|im_start|>assistant\n" in x["text"],
        desc="Filtering",
    )
    print(f"After filtering: {len(dataset)} examples")
    if len(dataset) == 0:
        raise ValueError(f"No usable examples left in {DATASET_NAME} after filtering")
    # Show one sample
    print("\n--- Sample data ---")
    sample = dataset[0]["text"]
    print(sample[:500] + "..." if len(sample) > 500 else sample)
    print("--- End sample ---\n")
    # Shuffle, then split into train/test. The explicit shuffle is kept so
    # that the split matches the one produced by earlier runs.
    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=0.02, seed=42)
    print(f"Train: {len(split['train'])} examples")
    print(f"Test: {len(split['test'])} examples")
    return split
# ============================================================
# Training hyperparameters
# ============================================================
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    # Schedule: a single epoch (max_steps=-1 means "derive from epochs").
    num_train_epochs=1,
    max_steps=-1,
    # Batching sized for a 48 GB L40S; effective batch = 8 x 2 = 16.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=8,
    group_by_length=True,
    # Optimizer and LR schedule (relatively high LR for a one-epoch run).
    optim="paged_adamw_8bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    max_grad_norm=0.3,
    # Precision: bf16 only.
    bf16=True,
    fp16=False,
    # Throughput switches.
    # NOTE: peft's prepare_model_for_kbit_training enables gradient
    # checkpointing on its own unless told otherwise.
    gradient_checkpointing=False,
    torch_compile=False,  # avoid the one-off compilation time
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Logging / evaluation / checkpoint cadence.
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=3,  # keep only the 3 most recent checkpoints
    save_safetensors=True,
    load_best_model_at_end=False,
    report_to="none",
)
# ============================================================
# Main
# ============================================================
def main():
    """Run the QLoRA fine-tuning pipeline end to end.

    Steps:
      1. Exit with status 1 if no CUDA GPU is available.
      2. Build the ChatML dataset.
      3. Load the tokenizer and the 4-bit base model, and attach the LoRA
         adapter.
      4. Train, resuming from the newest ``checkpoint-<step>`` directory
         under ``CHECKPOINT_DIR`` if one exists.
      5. Save the adapter and tokenizer to ``FINAL_OUTPUT_DIR``.
      6. Push both to the Hub as a private repository. A failed upload is
         reported but does not abort, since the artifacts are already saved
         locally.
    """
    # Local import; transformers is already a dependency of this script.
    from transformers.trainer_utils import get_last_checkpoint

    print("\n" + "=" * 70)
    print(" Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training")
    print("=" * 70)
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Base model: {BASE_MODEL}")
    print(f"Dataset: {DATASET_NAME}")
    print(f"Output: {OUTPUT_MODEL_ID}")
    print("=" * 70 + "\n")

    # GPU check: this script requires a CUDA device.
    if not torch.cuda.is_available():
        print("ERROR: No GPU available!", file=sys.stderr)
        sys.exit(1)
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_mem:.1f} GB")

    # Dataset
    dataset = load_and_prepare_dataset()

    # Tokenizer
    print(f"\nLoading tokenizer: {BASE_MODEL}")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Model (4-bit quantized)
    print(f"\nLoading model: {BASE_MODEL} (4-bit quantized)")
    print("This may take a few minutes...")
    # trust_remote_code is not passed: the Qwen2 architecture ships with
    # transformers, so the flag only allowed Hub-hosted code to execute.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="sdpa",
    )
    # The KV cache only helps generation; it is not needed while training.
    model.config.use_cache = False

    # Training preparation
    print("\nPreparing model for training...")
    # prepare_model_for_kbit_training() turns gradient checkpointing ON by
    # default, which silently overrode gradient_checkpointing=False in
    # training_args. Follow the TrainingArguments setting explicitly.
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=training_args.gradient_checkpointing,
    )
    model = get_peft_model(model, lora_config)
    print("\nTrainable parameters:")
    model.print_trainable_parameters()

    # NOTE(review): tokenizer / max_seq_length / packing / dataset_text_field
    # are passed directly as older TRL releases expect; newer releases take
    # them via SFTConfig / processing_class - confirm against the pinned TRL.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        args=training_args,
        # peft_config is intentionally not passed: the model is already a
        # PeftModel, so handing lora_config over again is redundant at best.
        tokenizer=tokenizer,
        max_seq_length=1024,  # kept short for throughput
        packing=False,  # one example per sequence
        dataset_text_field="text",
        callbacks=[VerboseLoggingCallback()],
    )

    # Resume from the newest checkpoint-<step> directory, if any.
    # get_last_checkpoint ignores non-directories and non-numeric suffixes.
    resume_from = None
    if os.path.isdir(CHECKPOINT_DIR):
        resume_from = get_last_checkpoint(CHECKPOINT_DIR)
    if resume_from is not None:
        print(f"\n📂 Found checkpoint: {resume_from}")
        print(" Resuming from checkpoint...")

    # Train
    print("\n" + "=" * 70)
    print("Starting training...")
    print("=" * 70)
    trainer.train(resume_from_checkpoint=resume_from)

    # Save final model
    print(f"\nSaving final model to {FINAL_OUTPUT_DIR}...")
    trainer.save_model(FINAL_OUTPUT_DIR)
    tokenizer.save_pretrained(FINAL_OUTPUT_DIR)

    # Upload to the Hugging Face Hub (best-effort; artifacts are on disk).
    print(f"\nUploading to HuggingFace: {OUTPUT_MODEL_ID}")
    try:
        trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True)
        tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True)
    except Exception as e:
        print(f"⚠️ Upload failed: {e}")
        print(" Model saved locally. Please upload manually.")
    else:
        print(f"✅ Model uploaded to: https://huggingface.co/{OUTPUT_MODEL_ID}")

    print("\n" + "=" * 70)
    print("🎉 Training complete!")
    print("=" * 70)


if __name__ == "__main__":
    main()