Hajime MATSUMOTO committed on
Commit 496eb13 · 1 Parent(s): 113833d

Add Colab training notebook

Files changed (1)
  1. colab_training.ipynb +344 -0
colab_training.ipynb ADDED
@@ -0,0 +1,344 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Qwen2.5-7B QLoRA Training on Colab\n",
+ "\n",
+ "Training notebook for Google Colab Pro (A100)\n",
+ "\n",
+ "**Recommended**: Colab Pro ($10/month) or higher, with an A100 GPU"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Environment Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check the GPU\n",
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mount Google Drive (for saving checkpoints)\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')\n",
+ "\n",
+ "# Create working directories\n",
+ "!mkdir -p /content/drive/MyDrive/qwen-training/checkpoints\n",
+ "!mkdir -p /content/drive/MyDrive/qwen-training/output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies\n",
+ "!pip install -q torch==2.2.0 torchvision==0.17.0\n",
+ "!pip install -q transformers==4.46.0 datasets peft==0.13.0 trl==0.11.0\n",
+ "!pip install -q bitsandbytes accelerate huggingface_hub safetensors"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Log in to Hugging Face\n",
+ "from huggingface_hub import login\n",
+ "login()  # enter your token when prompted"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Configuration\n",
+ "BASE_MODEL = \"Qwen/Qwen2.5-7B-Instruct\"\n",
+ "OUTPUT_MODEL_ID = \"hajimemat/qwen2.5-7b-glaive-fc-lora-colab\"  # change as needed\n",
+ "DATASET_NAME = \"glaiveai/glaive-function-calling-v2\"\n",
+ "\n",
+ "# Save to Google Drive\n",
+ "CHECKPOINT_DIR = \"/content/drive/MyDrive/qwen-training/checkpoints\"\n",
+ "FINAL_OUTPUT_DIR = \"/content/drive/MyDrive/qwen-training/output\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Dataset Preparation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "def convert_glaive_to_chatml(example):\n",
+ "    parts = []\n",
+ "    if example.get(\"system\"):\n",
+ "        parts.append(f\"<|im_start|>system\\n{example['system']}<|im_end|>\")\n",
+ "\n",
+ "    chat = example.get(\"chat\", \"\")\n",
+ "    if chat:\n",
+ "        current_role = None\n",
+ "        current_content = []\n",
+ "        for line in chat.split(\"\\n\"):\n",
+ "            line = line.strip()\n",
+ "            if line.startswith(\"USER:\"):\n",
+ "                if current_role and current_content:\n",
+ "                    content = \"\\n\".join(current_content).strip()\n",
+ "                    if content:\n",
+ "                        parts.append(f\"<|im_start|>{current_role}\\n{content}<|im_end|>\")\n",
+ "                current_role = \"user\"\n",
+ "                current_content = [line[5:].strip()]\n",
+ "            elif line.startswith(\"ASSISTANT:\"):\n",
+ "                if current_role and current_content:\n",
+ "                    content = \"\\n\".join(current_content).strip()\n",
+ "                    if content:\n",
+ "                        parts.append(f\"<|im_start|>{current_role}\\n{content}<|im_end|>\")\n",
+ "                current_role = \"assistant\"\n",
+ "                current_content = [line[10:].strip()]\n",
+ "            elif current_role:\n",
+ "                current_content.append(line)\n",
+ "        if current_role and current_content:\n",
+ "            content = \"\\n\".join(current_content).strip()\n",
+ "            if content:\n",
+ "                parts.append(f\"<|im_start|>{current_role}\\n{content}<|im_end|>\")\n",
+ "    return {\"text\": \"\\n\".join(parts)}\n",
+ "\n",
+ "print(f\"Loading dataset: {DATASET_NAME}\")\n",
+ "dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
+ "print(f\"Original: {len(dataset)} examples\")\n",
+ "\n",
+ "dataset = dataset.map(convert_glaive_to_chatml, remove_columns=dataset.column_names, num_proc=4)\n",
+ "dataset = dataset.filter(lambda x: len(x[\"text\"]) > 50)\n",
+ "dataset = dataset.shuffle(seed=42)\n",
+ "split = dataset.train_test_split(test_size=0.02, seed=42)\n",
+ "\n",
+ "print(f\"Train: {len(split['train'])}, Test: {len(split['test'])}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Model Preparation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments\n",
+ "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
+ "\n",
+ "# QLoRA quantization config\n",
+ "bnb_config = BitsAndBytesConfig(\n",
+ "    load_in_4bit=True,\n",
+ "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
+ "    bnb_4bit_quant_type=\"nf4\",\n",
+ "    bnb_4bit_use_double_quant=True,\n",
+ ")\n",
+ "\n",
+ "# LoRA config\n",
+ "lora_config = LoraConfig(\n",
+ "    r=64,\n",
+ "    lora_alpha=16,\n",
+ "    lora_dropout=0.05,\n",
+ "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ "    bias=\"none\",\n",
+ "    task_type=\"CAUSAL_LM\",\n",
+ ")\n",
+ "\n",
+ "# Tokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
+ "tokenizer.padding_side = \"right\"\n",
+ "if tokenizer.pad_token is None:\n",
+ "    tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "# Model\n",
+ "print(f\"Loading model: {BASE_MODEL}\")\n",
+ "model = AutoModelForCausalLM.from_pretrained(\n",
+ "    BASE_MODEL,\n",
+ "    quantization_config=bnb_config,\n",
+ "    device_map=\"auto\",\n",
+ "    attn_implementation=\"sdpa\",\n",
+ "    trust_remote_code=True,\n",
+ ")\n",
+ "\n",
+ "model = prepare_model_for_kbit_training(model)\n",
+ "model = get_peft_model(model, lora_config)\n",
+ "model.print_trainable_parameters()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Run Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from trl import SFTTrainer\n",
+ "\n",
+ "training_args = TrainingArguments(\n",
+ "    output_dir=CHECKPOINT_DIR,\n",
+ "    num_train_epochs=1,\n",
+ "    per_device_train_batch_size=4,\n",
+ "    per_device_eval_batch_size=4,\n",
+ "    gradient_accumulation_steps=4,\n",
+ "    learning_rate=2e-4,\n",
+ "    weight_decay=0.01,\n",
+ "    warmup_ratio=0.03,\n",
+ "    lr_scheduler_type=\"cosine\",\n",
+ "    optim=\"paged_adamw_8bit\",\n",
+ "    bf16=True,\n",
+ "    logging_steps=10,\n",
+ "    save_steps=200,\n",
+ "    save_total_limit=3,\n",
+ "    eval_strategy=\"steps\",\n",
+ "    eval_steps=200,\n",
+ "    report_to=\"none\",\n",
+ "    gradient_checkpointing=True,\n",
+ "    save_safetensors=True,\n",
+ ")\n",
+ "\n",
+ "trainer = SFTTrainer(\n",
+ "    model=model,\n",
+ "    train_dataset=split[\"train\"],\n",
+ "    eval_dataset=split[\"test\"],\n",
+ "    args=training_args,\n",
+ "    peft_config=lora_config,\n",
+ "    tokenizer=tokenizer,\n",
+ "    max_seq_length=1024,\n",
+ "    packing=False,\n",
+ "    dataset_text_field=\"text\",\n",
+ ")\n",
+ "\n",
+ "# Resume from the latest checkpoint if one exists\n",
+ "import os\n",
+ "resume_from = None\n",
+ "if os.path.exists(CHECKPOINT_DIR):\n",
+ "    checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith(\"checkpoint-\")]\n",
+ "    if checkpoints:\n",
+ "        latest = max(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n",
+ "        resume_from = os.path.join(CHECKPOINT_DIR, latest)\n",
+ "        print(f\"Resuming from: {resume_from}\")\n",
+ "\n",
+ "# Start training\n",
+ "trainer.train(resume_from_checkpoint=resume_from)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Save and Upload"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save locally\n",
+ "print(f\"Saving to {FINAL_OUTPUT_DIR}\")\n",
+ "trainer.save_model(FINAL_OUTPUT_DIR)\n",
+ "tokenizer.save_pretrained(FINAL_OUTPUT_DIR)\n",
+ "\n",
+ "# Upload to Hugging Face\n",
+ "print(f\"Uploading to {OUTPUT_MODEL_ID}\")\n",
+ "try:\n",
+ "    trainer.model.push_to_hub(OUTPUT_MODEL_ID, private=True)\n",
+ "    tokenizer.push_to_hub(OUTPUT_MODEL_ID, private=True)\n",
+ "    print(f\"Done! https://huggingface.co/{OUTPUT_MODEL_ID}\")\n",
+ "except Exception as e:\n",
+ "    print(f\"Upload failed: {e}\")\n",
+ "    print(\"Model saved locally in Google Drive\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Quick Test (Optional)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple inference test\n",
+ "from peft import PeftModel\n",
+ "\n",
+ "test_prompt = \"\"\"<|im_start|>system\n",
+ "You are a helpful assistant with access to functions.\n",
+ "<|im_end|>\n",
+ "<|im_start|>user\n",
+ "What's the weather in Tokyo?\n",
+ "<|im_end|>\n",
+ "<|im_start|>assistant\n",
+ "\"\"\"\n",
+ "\n",
+ "inputs = tokenizer(test_prompt, return_tensors=\"pt\").to(model.device)\n",
+ "outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)\n",
+ "print(tokenizer.decode(outputs[0], skip_special_tokens=False))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "A100",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
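
For reloading the trained adapter in a fresh session (e.g. after a Colab restart), here is a minimal sketch, assuming the same BASE_MODEL, quantization config, and Drive output path used in the notebook above; loading from the Hub repo OUTPUT_MODEL_ID instead of the Drive path should work the same way once the upload has succeeded:

# Minimal reload sketch (assumption: run in a new session after training finished
# and trainer.save_model(FINAL_OUTPUT_DIR) wrote the adapter to Google Drive).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
ADAPTER_PATH = "/content/drive/MyDrive/qwen-training/output"  # or OUTPUT_MODEL_ID on the Hub

# Same 4-bit config as training keeps memory usage low at inference time
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)

# Attach the saved LoRA weights to the quantized base model
model = PeftModel.from_pretrained(base, ADAPTER_PATH)
model.eval()

prompt = "<|im_start|>user\nWhat's the weather in Tokyo?<|im_end|>\n<|im_start|>assistant\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
print(tokenizer.decode(out[0], skip_special_tokens=False))

For quick checks the adapter can stay separate as above; to export a single merged model, the usual approach is to reload the base model in full precision and call merge_and_unload() on the PeftModel.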