Hajime MATSUMOTO committed on
Commit
6d8a316
·
1 Parent(s): d6f9891

Add 7B QLoRA training + AWQ quantization scripts

Browse files
Files changed (5) hide show
  1. Dockerfile +24 -0
  2. README.md +26 -4
  3. merge_and_quantize.py +204 -0
  4. requirements.txt +10 -0
  5. train.py +376 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel
2
+
3
+ WORKDIR /app
4
+
5
+ # 基本パッケージ
6
+ RUN apt-get update && apt-get install -y \
7
+ git \
8
+ curl \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Python依存関係
12
+ COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ # 学習スクリプト
16
+ COPY train.py .
17
+
18
+ # HFトークンは環境変数で渡す
19
+ ENV HF_TOKEN=""
20
+ ENV TRANSFORMERS_CACHE=/app/cache
21
+ ENV HF_HOME=/app/cache
22
+
23
+ # 学習実行
24
+ CMD ["python", "train.py"]
README.md CHANGED
@@ -1,10 +1,32 @@
1
  ---
2
- title: Glaive 7b Training
3
- emoji: 🐢
4
  colorFrom: blue
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Glaive 7B Training
3
+ emoji: 🚀
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
+ # Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training
11
+
12
+ Function Calling能力強化のための学習
13
+
14
+ ## 概要
15
+
16
+ - **ベースモデル**: Qwen/Qwen2.5-7B-Instruct
17
+ - **データセット**: glaiveai/glaive-function-calling-v2 (約113k samples)
18
+ - **手法**: QLoRA (4-bit量子化 + LoRA)
19
+ - **出力**: hajimemat/qwen2.5-7b-glaive-fc-lora
20
+
21
+ ## 特徴
22
+
23
+ - 10ステップごとにログ出力(Loss, LR, ETA)
24
+ - 500ステップごとにチェックポイント保存
25
+ - 中断しても自動再開対応
26
+ - 完了時にHFへ自動アップロード
27
+
28
+ ## 学習後
29
+
30
+ 1. LoRAをベースモデルにマージ
31
+ 2. AWQ量子化
32
+ 3. vLLMでデプロイ
merge_and_quantize.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LoRAマージ + AWQ量子化スクリプト
4
+
5
+ 学習完了後に実行:
6
+ 1. LoRAアダプターをベースモデルにマージ
7
+ 2. AWQ量子化(4bit)
8
+ 3. HuggingFaceにアップロード
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import shutil
14
+ from datetime import datetime
15
+
16
+ import torch
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from peft import PeftModel
19
+ from awq import AutoAWQForCausalLM
20
+
21
+ # ============================================================
22
+ # 設定
23
+ # ============================================================
24
+ BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
25
+ LORA_MODEL = "hajimemat/qwen2.5-7b-glaive-fc-lora" # 学習済みLoRA
26
+
27
+ # 出力先
28
+ MERGED_MODEL_DIR = "./merged_model"
29
+ QUANTIZED_MODEL_DIR = "./quantized_model"
30
+ OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-awq"
31
+
32
+ # AWQ量子化設定
33
+ AWQ_CONFIG = {
34
+ "zero_point": True,
35
+ "q_group_size": 128,
36
+ "w_bit": 4,
37
+ "version": "GEMM"
38
+ }
39
+
40
+
41
+ def step1_merge_lora():
42
+ """Step 1: LoRAをベースモデルにマージ"""
43
+ print("\n" + "=" * 60)
44
+ print("Step 1: Merging LoRA adapter to base model")
45
+ print("=" * 60)
46
+
47
+ print(f"Base model: {BASE_MODEL}")
48
+ print(f"LoRA model: {LORA_MODEL}")
49
+
50
+ # ベースモデル読み込み
51
+ print("\nLoading base model...")
52
+ base_model = AutoModelForCausalLM.from_pretrained(
53
+ BASE_MODEL,
54
+ torch_dtype=torch.bfloat16,
55
+ device_map="auto",
56
+ trust_remote_code=True,
57
+ )
58
+
59
+ # トークナイザー読み込み
60
+ print("Loading tokenizer...")
61
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
62
+
63
+ # LoRAアダプター適用
64
+ print("Loading LoRA adapter...")
65
+ model = PeftModel.from_pretrained(base_model, LORA_MODEL)
66
+
67
+ # マージ
68
+ print("Merging LoRA weights...")
69
+ model = model.merge_and_unload()
70
+
71
+ # 保存
72
+ print(f"Saving merged model to {MERGED_MODEL_DIR}...")
73
+ model.save_pretrained(MERGED_MODEL_DIR, safe_serialization=True)
74
+ tokenizer.save_pretrained(MERGED_MODEL_DIR)
75
+
76
+ # メモリ解放
77
+ del model
78
+ del base_model
79
+ torch.cuda.empty_cache()
80
+
81
+ print("✅ Step 1 complete: LoRA merged")
82
+ return MERGED_MODEL_DIR
83
+
84
+
85
+ def step2_quantize_awq(merged_model_path):
86
+ """Step 2: AWQ量子化"""
87
+ print("\n" + "=" * 60)
88
+ print("Step 2: AWQ Quantization (4-bit)")
89
+ print("=" * 60)
90
+
91
+ print(f"Input model: {merged_model_path}")
92
+ print(f"AWQ config: {AWQ_CONFIG}")
93
+
94
+ # モデル読み込み
95
+ print("\nLoading merged model for quantization...")
96
+ model = AutoAWQForCausalLM.from_pretrained(
97
+ merged_model_path,
98
+ trust_remote_code=True,
99
+ safetensors=True,
100
+ )
101
+ tokenizer = AutoTokenizer.from_pretrained(merged_model_path)
102
+
103
+ # 量子化用キャリブレーションデータ
104
+ # シンプルなサンプルで十分
105
+ calib_data = [
106
+ "Hello, how are you today?",
107
+ "What is the weather like?",
108
+ "Can you help me with coding?",
109
+ "Search for files containing docker",
110
+ "List all repositories in the project",
111
+ "Find the definition of the function main",
112
+ "こんにちは、今日の天気はどうですか?",
113
+ "プロジェクトのファイル構成を教えてください",
114
+ ]
115
+
116
+ # 量子化実行
117
+ print("\nQuantizing model (this may take a while)...")
118
+ model.quantize(
119
+ tokenizer,
120
+ quant_config=AWQ_CONFIG,
121
+ calib_data=calib_data,
122
+ )
123
+
124
+ # 保存
125
+ print(f"\nSaving quantized model to {QUANTIZED_MODEL_DIR}...")
126
+ model.save_quantized(QUANTIZED_MODEL_DIR, safetensors=True)
127
+ tokenizer.save_pretrained(QUANTIZED_MODEL_DIR)
128
+
129
+ print("✅ Step 2 complete: AWQ quantization done")
130
+ return QUANTIZED_MODEL_DIR
131
+
132
+
133
+ def step3_upload_to_hub(quantized_model_path):
134
+ """Step 3: HuggingFaceにアップロード"""
135
+ print("\n" + "=" * 60)
136
+ print("Step 3: Upload to HuggingFace Hub")
137
+ print("=" * 60)
138
+
139
+ print(f"Uploading to: {OUTPUT_MODEL_ID}")
140
+
141
+ from huggingface_hub import HfApi, upload_folder
142
+
143
+ api = HfApi()
144
+
145
+ # リポジトリ作成(存在しなければ)
146
+ try:
147
+ api.create_repo(OUTPUT_MODEL_ID, private=True, exist_ok=True)
148
+ except Exception as e:
149
+ print(f"Note: {e}")
150
+
151
+ # アップロード
152
+ print("Uploading files...")
153
+ upload_folder(
154
+ folder_path=quantized_model_path,
155
+ repo_id=OUTPUT_MODEL_ID,
156
+ repo_type="model",
157
+ )
158
+
159
+ print(f"✅ Step 3 complete: Uploaded to https://huggingface.co/{OUTPUT_MODEL_ID}")
160
+
161
+
162
+ def main():
163
+ print("\n" + "=" * 70)
164
+ print(" LoRA Merge + AWQ Quantization Pipeline")
165
+ print("=" * 70)
166
+ print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
167
+ print(f"Base: {BASE_MODEL}")
168
+ print(f"LoRA: {LORA_MODEL}")
169
+ print(f"Output: {OUTPUT_MODEL_ID}")
170
+ print("=" * 70)
171
+
172
+ # GPU確認
173
+ if torch.cuda.is_available():
174
+ print(f"\nGPU: {torch.cuda.get_device_name(0)}")
175
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
176
+ else:
177
+ print("WARNING: No GPU available, this will be slow!")
178
+
179
+ # Step 1: マージ
180
+ merged_path = step1_merge_lora()
181
+
182
+ # Step 2: 量子化
183
+ quantized_path = step2_quantize_awq(merged_path)
184
+
185
+ # Step 3: アップロード
186
+ step3_upload_to_hub(quantized_path)
187
+
188
+ # クリーンアップ(オプション)
189
+ print("\n" + "=" * 60)
190
+ print("Cleanup")
191
+ print("=" * 60)
192
+ cleanup = input("Delete intermediate files? (merged_model/) [y/N]: ").strip().lower()
193
+ if cleanup == 'y':
194
+ shutil.rmtree(MERGED_MODEL_DIR, ignore_errors=True)
195
+ print("Cleaned up merged_model/")
196
+
197
+ print("\n" + "=" * 70)
198
+ print("🎉 Pipeline complete!")
199
+ print(f"Model available at: https://huggingface.co/{OUTPUT_MODEL_ID}")
200
+ print("=" * 70)
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.40.0
3
+ datasets>=2.18.0
4
+ peft>=0.10.0
5
+ trl>=0.8.0
6
+ bitsandbytes>=0.43.0
7
+ accelerate>=0.28.0
8
+ huggingface_hub>=0.22.0
9
+ safetensors>=0.4.0
10
+ autoawq>=0.2.0
train.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Qwen2.5-7B-Instruct + glaive-function-calling-v2 QLoRA学習スクリプト
4
+
5
+ 目的: Function Calling能力の強化
6
+ データセット: glaiveai/glaive-function-calling-v2 (113k samples)
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import time
12
+ from datetime import datetime
13
+
14
+ import torch
15
+ from datasets import load_dataset
16
+ from transformers import (
17
+ AutoModelForCausalLM,
18
+ AutoTokenizer,
19
+ BitsAndBytesConfig,
20
+ TrainingArguments,
21
+ )
22
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
23
+ from trl import SFTTrainer
24
+ from transformers.trainer_callback import TrainerCallback
25
+
26
+ # ============================================================
27
+ # 設定
28
+ # ============================================================
29
+ BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
30
+ OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-lora"
31
+ DATASET_NAME = "glaiveai/glaive-function-calling-v2"
32
+
33
+ # チェックポイント設定
34
+ CHECKPOINT_DIR = "./checkpoints"
35
+ FINAL_OUTPUT_DIR = "./output/final"
36
+
37
+ # ============================================================
38
+ # QLoRA量子化設定
39
+ # ============================================================
40
+ bnb_config = BitsAndBytesConfig(
41
+ load_in_4bit=True,
42
+ bnb_4bit_compute_dtype=torch.bfloat16,
43
+ bnb_4bit_quant_type="nf4",
44
+ bnb_4bit_use_double_quant=True,
45
+ )
46
+
47
+ # ============================================================
48
+ # LoRA設定
49
+ # ============================================================
50
+ lora_config = LoraConfig(
51
+ r=64,
52
+ lora_alpha=16,
53
+ lora_dropout=0.05,
54
+ target_modules=[
55
+ "q_proj", "k_proj", "v_proj", "o_proj",
56
+ "gate_proj", "up_proj", "down_proj"
57
+ ],
58
+ bias="none",
59
+ task_type="CAUSAL_LM",
60
+ )
61
+
62
+
63
+ # ============================================================
64
+ # カスタムコールバック: 定期ログ出力
65
+ # ============================================================
66
+ class VerboseLoggingCallback(TrainerCallback):
67
+ """詳細なログ出力用コールバック"""
68
+
69
+ def __init__(self):
70
+ self.start_time = None
71
+ self.last_log_time = None
72
+
73
+ def on_train_begin(self, args, state, control, **kwargs):
74
+ self.start_time = time.time()
75
+ self.last_log_time = self.start_time
76
+ print("\n" + "=" * 70)
77
+ print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training started")
78
+ print(f" Total steps: {state.max_steps}")
79
+ print(f" Epochs: {args.num_train_epochs}")
80
+ print(f" Batch size: {args.per_device_train_batch_size} x {args.gradient_accumulation_steps}")
81
+ print("=" * 70 + "\n")
82
+
83
+ def on_log(self, args, state, control, logs=None, **kwargs):
84
+ if logs is None:
85
+ return
86
+
87
+ current_time = time.time()
88
+ elapsed = current_time - self.start_time
89
+ elapsed_str = time.strftime("%H:%M:%S", time.gmtime(elapsed))
90
+
91
+ # 進捗計算
92
+ progress = state.global_step / state.max_steps * 100 if state.max_steps > 0 else 0
93
+
94
+ # ETA計算
95
+ if state.global_step > 0:
96
+ time_per_step = elapsed / state.global_step
97
+ remaining_steps = state.max_steps - state.global_step
98
+ eta_seconds = time_per_step * remaining_steps
99
+ eta_str = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))
100
+ else:
101
+ eta_str = "calculating..."
102
+
103
+ # ログ出力
104
+ loss = logs.get("loss", "N/A")
105
+ lr = logs.get("learning_rate", "N/A")
106
+
107
+ print(f"[{datetime.now().strftime('%H:%M:%S')}] "
108
+ f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%) | "
109
+ f"Loss: {loss:.4f if isinstance(loss, float) else loss} | "
110
+ f"LR: {lr:.2e if isinstance(lr, float) else lr} | "
111
+ f"Elapsed: {elapsed_str} | ETA: {eta_str}")
112
+
113
+ # GPU メモリ使用量(10ステップごと)
114
+ if state.global_step % 100 == 0 and torch.cuda.is_available():
115
+ allocated = torch.cuda.memory_allocated() / 1e9
116
+ reserved = torch.cuda.memory_reserved() / 1e9
117
+ print(f" GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
118
+
119
+ def on_save(self, args, state, control, **kwargs):
120
+ print(f"\n[{datetime.now().strftime('%H:%M:%S')}] "
121
+ f"💾 Checkpoint saved at step {state.global_step}\n")
122
+
123
+ def on_train_end(self, args, state, control, **kwargs):
124
+ total_time = time.time() - self.start_time
125
+ total_str = time.strftime("%H:%M:%S", time.gmtime(total_time))
126
+ print("\n" + "=" * 70)
127
+ print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Training completed!")
128
+ print(f" Total time: {total_str}")
129
+ print(f" Final step: {state.global_step}")
130
+ print("=" * 70 + "\n")
131
+
132
+
133
+ # ============================================================
134
+ # データセット変換
135
+ # ============================================================
136
+ def convert_glaive_to_chatml(example: dict) -> dict:
137
+ """
138
+ glaive-function-calling-v2形式をChatML形式に変換
139
+
140
+ 元データ形式:
141
+ - system: 関数定義を含むシステムプロンプト
142
+ - chat: "USER: ... ASSISTANT: ..." 形式の会話
143
+ """
144
+ parts = []
145
+
146
+ # システムプロンプト
147
+ if example.get("system"):
148
+ parts.append(f"<|im_start|>system\n{example['system']}<|im_end|>")
149
+
150
+ # 会話を解析
151
+ chat = example.get("chat", "")
152
+ if chat:
153
+ # "USER:" と "ASSISTANT:" で分割
154
+ # 複数ターンに対応
155
+ current_role = None
156
+ current_content = []
157
+
158
+ for line in chat.split("\n"):
159
+ line = line.strip()
160
+ if line.startswith("USER:"):
161
+ # 前のメッセージを保存
162
+ if current_role and current_content:
163
+ content = "\n".join(current_content).strip()
164
+ if content:
165
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
166
+ current_role = "user"
167
+ current_content = [line[5:].strip()] # "USER:" を除去
168
+ elif line.startswith("ASSISTANT:"):
169
+ # 前のメッセージを保存
170
+ if current_role and current_content:
171
+ content = "\n".join(current_content).strip()
172
+ if content:
173
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
174
+ current_role = "assistant"
175
+ current_content = [line[10:].strip()] # "ASSISTANT:" を除去
176
+ elif current_role:
177
+ current_content.append(line)
178
+
179
+ # 最後のメッセージを保存
180
+ if current_role and current_content:
181
+ content = "\n".join(current_content).strip()
182
+ if content:
183
+ parts.append(f"<|im_start|>{current_role}\n{content}<|im_end|>")
184
+
185
+ return {"text": "\n".join(parts)}
186
+
187
+
188
+ def load_and_prepare_dataset():
189
+ """データセットを読み込んで前処理"""
190
+ print(f"\n{'=' * 60}")
191
+ print(f"Loading dataset: {DATASET_NAME}")
192
+ print(f"{'=' * 60}")
193
+
194
+ # データセット読み込み
195
+ dataset = load_dataset(DATASET_NAME, split="train")
196
+ print(f"Original size: {len(dataset)} examples")
197
+
198
+ # 変換
199
+ print("Converting to ChatML format...")
200
+ dataset = dataset.map(
201
+ convert_glaive_to_chatml,
202
+ remove_columns=dataset.column_names,
203
+ num_proc=4,
204
+ desc="Converting"
205
+ )
206
+
207
+ # 空のサンプルをフィルタ
208
+ dataset = dataset.filter(lambda x: len(x["text"]) > 50)
209
+ print(f"After filtering: {len(dataset)} examples")
210
+
211
+ # サンプル表示
212
+ print("\n--- Sample data ---")
213
+ sample = dataset[0]["text"]
214
+ print(sample[:500] + "..." if len(sample) > 500 else sample)
215
+ print("--- End sample ---\n")
216
+
217
+ # シャッフルしてTrain/Test分割
218
+ dataset = dataset.shuffle(seed=42)
219
+ split = dataset.train_test_split(test_size=0.02, seed=42)
220
+
221
+ print(f"Train: {len(split['train'])} examples")
222
+ print(f"Test: {len(split['test'])} examples")
223
+
224
+ return split
225
+
226
+
227
+ # ============================================================
228
+ # 学習パラメータ
229
+ # ============================================================
230
+ training_args = TrainingArguments(
231
+ output_dir=CHECKPOINT_DIR,
232
+
233
+ # エポック・ステップ
234
+ num_train_epochs=2,
235
+ max_steps=-1, # -1 = エポックベース
236
+
237
+ # バッチサイズ (7Bは3Bより小さく)
238
+ per_device_train_batch_size=2,
239
+ per_device_eval_batch_size=2,
240
+ gradient_accumulation_steps=16, # 有効バッチサイズ: 2*16=32
241
+
242
+ # 学習率
243
+ learning_rate=1e-4,
244
+ weight_decay=0.01,
245
+ warmup_ratio=0.03,
246
+ lr_scheduler_type="cosine",
247
+
248
+ # 最適化
249
+ optim="paged_adamw_8bit",
250
+ fp16=False,
251
+ bf16=True,
252
+ max_grad_norm=0.3,
253
+
254
+ # ログ・保存(重要!)
255
+ logging_steps=10, # 10ステップごとにログ
256
+ save_steps=500, # 500ステップごとにチェックポイント
257
+ save_total_limit=3, # 最新3つのチェックポイントを保持
258
+ eval_strategy="steps",
259
+ eval_steps=500, # 500ステップごとに評価
260
+
261
+ # その他
262
+ report_to="none",
263
+ group_by_length=True,
264
+ gradient_checkpointing=True,
265
+
266
+ # 再開用
267
+ save_safetensors=True,
268
+ load_best_model_at_end=False,
269
+ )
270
+
271
+
272
+ # ============================================================
273
+ # メイン
274
+ # ============================================================
275
+ def main():
276
+ print("\n" + "=" * 70)
277
+ print(" Qwen2.5-7B + glaive-function-calling-v2 QLoRA Training")
278
+ print("=" * 70)
279
+ print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
280
+ print(f"Base model: {BASE_MODEL}")
281
+ print(f"Dataset: {DATASET_NAME}")
282
+ print(f"Output: {OUTPUT_MODEL_ID}")
283
+ print("=" * 70 + "\n")
284
+
285
+ # GPU確認
286
+ if torch.cuda.is_available():
287
+ gpu_name = torch.cuda.get_device_name(0)
288
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
289
+ print(f"GPU: {gpu_name}")
290
+ print(f"VRAM: {gpu_mem:.1f} GB")
291
+ else:
292
+ print("ERROR: No GPU available!")
293
+ sys.exit(1)
294
+
295
+ # データセット読み込み
296
+ dataset = load_and_prepare_dataset()
297
+
298
+ # トークナイザー読み込み
299
+ print(f"\nLoading tokenizer: {BASE_MODEL}")
300
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
301
+ tokenizer.padding_side = "right"
302
+ if tokenizer.pad_token is None:
303
+ tokenizer.pad_token = tokenizer.eos_token
304
+
305
+ # モデル読み込み (4bit量子化)
306
+ print(f"\nLoading model: {BASE_MODEL} (4-bit quantized)")
307
+ print("This may take a few minutes...")
308
+ model = AutoModelForCausalLM.from_pretrained(
309
+ BASE_MODEL,
310
+ quantization_config=bnb_config,
311
+ device_map="auto",
312
+ attn_implementation="sdpa",
313
+ trust_remote_code=True,
314
+ )
315
+
316
+ # 学習準備
317
+ print("\nPreparing model for training...")
318
+ model = prepare_model_for_kbit_training(model)
319
+ model = get_peft_model(model, lora_config)
320
+
321
+ print("\nTrainable parameters:")
322
+ model.print_trainable_parameters()
323
+
324
+ # SFTTrainer設定
325
+ trainer = SFTTrainer(
326
+ model=model,
327
+ train_dataset=dataset["train"],
328
+ eval_dataset=dataset["test"],
329
+ args=training_args,
330
+ peft_config=lora_config,
331
+ processing_class=tokenizer,
332
+ max_seq_length=2048, # 7Bなので少し長く
333
+ packing=True,
334
+ dataset_text_field="text",
335
+ callbacks=[VerboseLoggingCallback()],
336
+ )
337
+
338
+ # チェックポイントからの再開確認
339
+ resume_from = None
340
+ if os.path.exists(CHECKPOINT_DIR):
341
+ checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")]
342
+ if checkpoints:
343
+ latest = max(checkpoints, key=lambda x: int(x.split("-")[1]))
344
+ resume_from = os.path.join(CHECKPOINT_DIR, latest)
345
+ print(f"\n📂 Found checkpoint: {resume_from}")
346
+ print(" Resuming from checkpoint...")
347
+
348
+ # 学習実行
349
+ print("\n" + "=" * 70)
350
+ print("Starting training...")
351
+ print("=" * 70)
352
+
353
+ trainer.train(resume_from_checkpoint=resume_from)
354
+
355
+ # 最終モデル保存
356
+ print(f"\nSaving final model to {FINAL_OUTPUT_DIR}...")
357
+ trainer.save_model(FINAL_OUTPUT_DIR)
358
+ tokenizer.save_pretrained(FINAL_OUTPUT_DIR)
359
+
360
+ # HFにアップロード
361
+ print(f"\nUploading to HuggingFace: {OUTPUT_MODEL_ID}")
362
+ try:
363
+ trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True)
364
+ tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True)
365
+ print(f"✅ Model uploaded to: https://huggingface.co/{OUTPUT_MODEL_ID}")
366
+ except Exception as e:
367
+ print(f"⚠️ Upload failed: {e}")
368
+ print(" Model saved locally. Please upload manually.")
369
+
370
+ print("\n" + "=" * 70)
371
+ print("🎉 Training complete!")
372
+ print("=" * 70)
373
+
374
+
375
+ if __name__ == "__main__":
376
+ main()