# Provenance: Hugging Face Spaces file by Hajime MATSUMOTO, commit 113833d.
# (The Space page showed a "Runtime error" status at capture time.)
# L40S optimization: batch 8, gradient checkpointing disabled, parallel dataloader.
| #!/usr/bin/env python3 | |
| """ | |
| Qwen2.5-7B-Instruct + glaive-function-calling-v2 QLoRA学習スクリプト | |
| 目的: Function Calling能力の強化 | |
| データセット: glaiveai/glaive-function-calling-v2 (113k samples) | |
| """ | |
| import os | |
| import sys | |
| import time | |
| from datetime import datetime | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| TrainingArguments, | |
| ) | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| from trl import SFTTrainer | |
| from transformers.trainer_callback import TrainerCallback | |
# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora-cloud"  # cloud-run variant
DATASET_NAME = "glaiveai/glaive-function-calling-v2"

# Checkpoint / final-output locations
CHECKPOINT_DIR = "./checkpoints"
FINAL_OUTPUT_DIR = "./output/final"

# ============================================================
# QLoRA 4-bit quantization settings
# ============================================================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # load base weights in 4-bit
    bnb_4bit_compute_dtype=torch.bfloat16,  # do the matmuls in bf16
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
)
# ============================================================
# LoRA adapter settings
# ============================================================
lora_config = LoraConfig(
    r=64,               # adapter rank
    lora_alpha=16,      # scaling factor
    lora_dropout=0.05,
    # Attach adapters to every attention and MLP projection.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
)
| # ============================================================ | |
| # カスタムコールバック: 定期ログ出力 | |
| # ============================================================ | |
class VerboseLoggingCallback(TrainerCallback):
    """Trainer callback printing verbose progress to stdout.

    On every logging event it reports step progress, loss, learning rate,
    elapsed wall-clock time and an ETA estimate derived from the average
    time per step so far.  GPU memory usage is reported every 100 global
    steps, and checkpoint saves / train start / train end are announced.
    """

    def __init__(self):
        self.start_time = None      # set in on_train_begin
        self.last_log_time = None   # kept for parity; never read afterwards

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start timestamp and print the opening banner."""
        now = time.time()
        self.start_time = now
        self.last_log_time = now
        banner = "=" * 70
        print("\n" + banner)
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
        print(f" Total steps: {state.max_steps}")
        print(f" Epochs: {args.num_train_epochs}")
        print(f" Batch size: {args.per_device_train_batch_size} x {args.gradient_accumulation_steps}")
        print(banner + "\n")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print one progress line per logging event."""
        if logs is None:
            return
        elapsed = time.time() - self.start_time
        elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))

        # Percentage of total steps completed.
        pct = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0

        # ETA from the average time per completed step.
        if state.global_step > 0:
            per_step = elapsed / state.global_step
            remaining = state.max_steps - state.global_step
            eta_str = time.strftime("%H:%M:%S", time.gmtime(per_step * remaining))
        else:
            eta_str = "calculating..."

        loss = logs.get("loss", "N/A")
        lr = logs.get("learning_rate", "N/A")
        loss_str = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
        lr_str = f"{lr:.2e}" if isinstance(lr, float) else str(lr)
        print(f"[{datetime.now().strftime('%H:%M:%S')}] "
              f"Step {state.global_step}/{state.max_steps} ({pct:.1f}%) | "
              f"Loss: {loss_str} | "
              f"LR: {lr_str} | "
              f"Elapsed: {elapsed_str} | ETA: {eta_str}")

        # GPU memory report every 100 global steps.
        if state.global_step % 100 == 0 and torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            print(f" GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

    def on_save(self, args, state, control, **kwargs):
        """Announce each checkpoint write."""
        stamp = datetime.now().strftime('%H:%M:%S')
        print(f"\n[{stamp}] 💾 Checkpoint saved at step {state.global_step}\n")

    def on_train_end(self, args, state, control, **kwargs):
        """Print the closing summary with total wall-clock time."""
        total_str = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.start_time))
        banner = "=" * 70
        print("\n" + banner)
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
        print(f" Total time: {total_str}")
        print(f" Final step: {state.global_step}")
        print(banner + "\n")
| # ============================================================ | |
| # データセット変換 | |
| # ============================================================ | |
def convert_glaive_to_chatml(example: dict) -> dict:
    """Convert one glaive-function-calling-v2 sample to ChatML text.

    Source format:
      - ``system``: system prompt containing the function definitions
      - ``chat``: conversation as plain text, turns introduced by
        ``USER:`` / ``ASSISTANT:`` line prefixes (multi-turn supported;
        lines without a prefix continue the current message)

    Returns:
        ``{"text": ...}`` — the conversation rendered as
        ``<|im_start|>role\\ncontent<|im_end|>`` blocks joined by newlines.
        Messages whose content strips to empty are dropped.
    """
    parts: list[str] = []

    def _flush(role, lines) -> None:
        # Emit the accumulated message for `role`, skipping empty bodies.
        if role and lines:
            content = "\n".join(lines).strip()
            if content:
                parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")

    # System prompt (skipped when absent or empty).
    if example.get("system"):
        parts.append(f"<|im_start|>system\n{example['system']}<|im_end|>")

    current_role = None
    current_content: list[str] = []
    chat = example.get("chat") or ""  # tolerate missing or None "chat"
    for line in chat.split("\n"):
        line = line.strip()
        if line.startswith("USER:"):
            _flush(current_role, current_content)
            current_role = "user"
            current_content = [line[len("USER:"):].strip()]
        elif line.startswith("ASSISTANT:"):
            _flush(current_role, current_content)
            current_role = "assistant"
            current_content = [line[len("ASSISTANT:"):].strip()]
        elif current_role:
            # Continuation line of the message in progress.
            current_content.append(line)
    # Flush the final message, if any.
    _flush(current_role, current_content)

    return {"text": "\n".join(parts)}
def load_and_prepare_dataset():
    """Download the dataset, convert it to ChatML, filter and split it.

    Returns a DatasetDict with ``train`` and ``test`` splits (98/2).
    """
    header = "=" * 60
    print(f"\n{header}")
    print(f"Loading dataset: {DATASET_NAME}")
    print(f"{header}")

    raw = load_dataset(DATASET_NAME, split="train")
    print(f"Original size: {len(raw)} examples")

    # Render every sample into a single "text" column.
    print("Converting to ChatML format...")
    converted = raw.map(
        convert_glaive_to_chatml,
        remove_columns=raw.column_names,
        num_proc=4,
        desc="Converting",
    )

    # Drop samples whose rendered text is too short to be useful.
    converted = converted.filter(lambda x: len(x["text"]) > 50)
    print(f"After filtering: {len(converted)} examples")

    # Show one example (truncated) for a quick sanity check.
    print("\n--- Sample data ---")
    sample = converted[0]["text"]
    print(sample[:500] + "..." if len(sample) > 500 else sample)
    print("--- End sample ---\n")

    # Shuffle, then carve out 2% as the eval split.
    shuffled = converted.shuffle(seed=42)
    split = shuffled.train_test_split(test_size=0.02, seed=42)
    print(f"Train: {len(split['train'])} examples")
    print(f"Test: {len(split['test'])} examples")
    return split
| # ============================================================ | |
| # 学習パラメータ | |
| # ============================================================ | |
# ============================================================
# Training hyperparameters
# ============================================================
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    # Epochs / steps
    num_train_epochs=1,
    max_steps=-1,                        # -1 = run by epochs, not steps
    # Batch sizes (aggressive settings for a 48 GB L40S)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,       # effective batch size: 8 * 2 = 16
    # Learning rate (high, to converge within a single epoch)
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Optimizer / precision
    optim="paged_adamw_8bit",
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    # Logging / checkpointing (important!)
    logging_steps=10,                    # log every 10 steps
    save_steps=500,                      # checkpoint every 500 steps
    save_total_limit=3,                  # keep only the 3 newest checkpoints
    eval_strategy="steps",
    eval_steps=500,                      # evaluate every 500 steps
    # Misc
    report_to="none",
    group_by_length=True,
    gradient_checkpointing=False,        # 48 GB is enough — keep it off for speed
    torch_compile=False,                 # skip first-run compile overhead
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Resume support
    save_safetensors=True,
    load_best_model_at_end=False,
)
| # ============================================================ | |
| # メイン | |
| # ============================================================ | |
| def main(): | |
| print("\n" + "=" * 70) | |
| print(" Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training") | |
| print("=" * 70) | |
| print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") | |
| print(f"Base model: {BASE_MODEL}") | |
| print(f"Dataset: {DATASET_NAME}") | |
| print(f"Output: {OUTPUT_MODEL_ID}") | |
| print("=" * 70 + "\n") | |
| # GPU確認 | |
| if torch.cuda.is_available(): | |
| gpu_name = torch.cuda.get_device_name(0) | |
| gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9 | |
| print(f"GPU: {gpu_name}") | |
| print(f"VRAM: {gpu_mem:.1f} GB") | |
| else: | |
| print("ERROR: No GPU available!") | |
| sys.exit(1) | |
| # データセット読み込み | |
| dataset = load_and_prepare_dataset() | |
| # トークナイザー読み込み | |
| print(f"\nLoading tokenizer: {BASE_MODEL}") | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True) | |
| tokenizer.padding_side = "right" | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # モデル読み込み (4bit量子化) | |
| print(f"\nLoading model: {BASE_MODEL} (4-bit quantized)") | |
| print("This may take a few minutes...") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| attn_implementation="sdpa", | |
| trust_remote_code=True, | |
| ) | |
| # 学習準備 | |
| print("\nPreparing model for training...") | |
| model = prepare_model_for_kbit_training(model) | |
| model = get_peft_model(model, lora_config) | |
| print("\nTrainable parameters:") | |
| model.print_trainable_parameters() | |
| # SFTTrainer設定 | |
| trainer = SFTTrainer( | |
| model=model, | |
| train_dataset=dataset["train"], | |
| eval_dataset=dataset["test"], | |
| args=training_args, | |
| peft_config=lora_config, | |
| tokenizer=tokenizer, | |
| max_seq_length=1024, # 高速化のため短縮 | |
| packing=False, # flash attention不要で高速化 | |
| dataset_text_field="text", | |
| callbacks=[VerboseLoggingCallback()], | |
| ) | |
| # チェックポイントからの再開確認 | |
| resume_from = None | |
| if os.path.exists(CHECKPOINT_DIR): | |
| checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")] | |
| if checkpoints: | |
| latest = max(checkpoints, key=lambda x: int(x.split("-")[1])) | |
| resume_from = os.path.join(CHECKPOINT_DIR, latest) | |
| print(f"\n📂 Found checkpoint: {resume_from}") | |
| print(" Resuming from checkpoint...") | |
| # 学習実行 | |
| print("\n" + "=" * 70) | |
| print("Starting training...") | |
| print("=" * 70) | |
| trainer.train(resume_from_checkpoint=resume_from) | |
| # 最終モデル保存 | |
| print(f"\nSaving final model to {FINAL_OUTPUT_DIR}...") | |
| trainer.save_model(FINAL_OUTPUT_DIR) | |
| tokenizer.save_pretrained(FINAL_OUTPUT_DIR) | |
| # HFにアップロード | |
| print(f"\nUploading to HuggingFace: {OUTPUT_MODEL_ID}") | |
| try: | |
| trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True) | |
| tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True) | |
| print(f"✅ Model uploaded to: https://huggingface.co/{OUTPUT_MODEL_ID}") | |
| except Exception as e: | |
| print(f"⚠️ Upload failed: {e}") | |
| print(" Model saved locally. Please upload manually.") | |
| print("\n" + "=" * 70) | |
| print("🎉 Training complete!") | |
| print("=" * 70) | |
| if __name__ == "__main__": | |
| main() | |