glaive-7b-training / merge_and_quantize.py
Hajime MATSUMOTO
Add 7B QLoRA training + AWQ quantization scripts
6d8a316
#!/usr/bin/env python3
"""
LoRAマージ + AWQ量子化スクリプト
学習完了後に実行:
1. LoRAアダプターをベースモデルにマージ
2. AWQ量子化(4bit)
3. HuggingFaceにアップロード
"""
import os
import sys
import shutil
from datetime import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from awq import AutoAWQForCausalLM
# ============================================================
# Configuration
# ============================================================
# Model identifiers: the base checkpoint and the trained LoRA adapter.
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
LORA_MODEL = "hajimemat/qwen2.5-7b-glaive-fc-lora"

# Output locations: local working directories and the target Hub repo id.
MERGED_MODEL_DIR = "./merged_model"
QUANTIZED_MODEL_DIR = "./quantized_model"
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-awq"

# AWQ quantization settings, passed to ``quantize`` as ``quant_config``.
AWQ_CONFIG = dict(
    zero_point=True,
    q_group_size=128,
    w_bit=4,
    version="GEMM",
)
def step1_merge_lora():
    """Step 1: merge the LoRA adapter into the base model.

    Loads ``BASE_MODEL`` in bfloat16, applies the ``LORA_MODEL`` adapter on
    top of it, folds the adapter weights into the base weights, and saves the
    merged model together with the base model's tokenizer to
    ``MERGED_MODEL_DIR`` in safetensors format.

    Returns:
        The directory the merged model was written to (``MERGED_MODEL_DIR``).
    """
    import gc  # local import: only needed for the cleanup at the end

    print("\n" + "=" * 60)
    print("Step 1: Merging LoRA adapter to base model")
    print("=" * 60)
    print(f"Base model: {BASE_MODEL}")
    print(f"LoRA model: {LORA_MODEL}")

    # Load the base model.
    print("\nLoading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Load the tokenizer (taken from the base model, not the adapter repo).
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Attach the LoRA adapter.
    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, LORA_MODEL)

    # Fold the adapter weights into the base weights and drop the PEFT wrapper.
    print("Merging LoRA weights...")
    model = model.merge_and_unload()

    # Save the merged model and tokenizer side by side.
    print(f"Saving merged model to {MERGED_MODEL_DIR}...")
    model.save_pretrained(MERGED_MODEL_DIR, safe_serialization=True)
    tokenizer.save_pretrained(MERGED_MODEL_DIR)

    # Release memory before the quantization step loads the model again.
    # Dropping the names alone is not always enough: a model dispatched with
    # device_map="auto" can be kept alive by reference cycles (presumably via
    # the dispatch hooks wrapped around each module's forward), so force a
    # garbage collection before asking the CUDA allocator to return its cache.
    del model
    del base_model
    gc.collect()
    torch.cuda.empty_cache()

    print("✅ Step 1 complete: LoRA merged")
    return MERGED_MODEL_DIR
def _repeat_to_min_tokens(tokenizer, texts, min_tokens):
    """Repeat ``texts`` until their combined token count reaches ``min_tokens``."""
    token_total = sum(len(tokenizer.encode(text.strip())) for text in texts)
    if token_total == 0:
        raise ValueError("Calibration texts contain no tokens")
    repeats = -(-min_tokens // token_total)  # ceiling division, always >= 1
    return list(texts) * repeats


def step2_quantize_awq(merged_model_path):
    """Step 2: quantize the merged model with AWQ (4-bit).

    Loads the merged model from ``merged_model_path``, quantizes it using
    ``AWQ_CONFIG`` and a small built-in calibration corpus, and saves the
    quantized model and tokenizer to ``QUANTIZED_MODEL_DIR``.

    Args:
        merged_model_path: Directory containing the merged model and its
            tokenizer (the output of step 1).

    Returns:
        The directory the quantized model was written to
        (``QUANTIZED_MODEL_DIR``).

    Raises:
        ValueError: If the calibration texts tokenize to nothing.
    """
    print("\n" + "=" * 60)
    print("Step 2: AWQ Quantization (4-bit)")
    print("=" * 60)
    print(f"Input model: {merged_model_path}")
    print(f"AWQ config: {AWQ_CONFIG}")

    # Load the merged model and its tokenizer.
    print("\nLoading merged model for quantization...")
    model = AutoAWQForCausalLM.from_pretrained(
        merged_model_path,
        trust_remote_code=True,
        safetensors=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(merged_model_path)

    # Calibration sentences covering the target usage (English and Japanese).
    calib_texts = [
        "Hello, how are you today?",
        "What is the weather like?",
        "Can you help me with coding?",
        "Search for files containing docker",
        "List all repositories in the project",
        "Find the definition of the function main",
        "こんにちは、今日の天気はどうですか?",
        "プロジェクトのファイル構成を教えてください",
    ]

    # AutoAWQ concatenates the tokenized samples and cuts them into
    # fixed-length blocks (512 tokens by default), dropping any remainder.
    # The sentences above add up to far less than one block, which would
    # leave the quantizer with no calibration data at all, so repeat them
    # until at least one full block is available.
    # NOTE: repetition only makes the run viable; a larger, representative
    # corpus would give better calibration quality.
    calib_block_tokens = 512
    calib_data = _repeat_to_min_tokens(
        tokenizer, calib_texts, min_tokens=calib_block_tokens
    )

    # Run the quantization.
    print("\nQuantizing model (this may take a while)...")
    model.quantize(
        tokenizer,
        quant_config=AWQ_CONFIG,
        calib_data=calib_data,
    )

    # Save the quantized weights and the tokenizer together.
    print(f"\nSaving quantized model to {QUANTIZED_MODEL_DIR}...")
    model.save_quantized(QUANTIZED_MODEL_DIR, safetensors=True)
    tokenizer.save_pretrained(QUANTIZED_MODEL_DIR)

    print("✅ Step 2 complete: AWQ quantization done")
    return QUANTIZED_MODEL_DIR
def step3_upload_to_hub(quantized_model_path):
    """Step 3: publish the quantized model folder to the Hugging Face Hub.

    Creates the private repository ``OUTPUT_MODEL_ID`` if needed (a failure
    here is only reported, not raised) and then uploads the contents of
    ``quantized_model_path`` to it as a model repository.

    Args:
        quantized_model_path: Local directory holding the quantized model.
    """
    from huggingface_hub import HfApi, upload_folder

    rule = "=" * 60
    print("\n" + rule)
    print("Step 3: Upload to HuggingFace Hub")
    print(rule)
    print(f"Uploading to: {OUTPUT_MODEL_ID}")

    hub_api = HfApi()

    # Create the repository when it does not exist yet; best-effort only.
    try:
        hub_api.create_repo(OUTPUT_MODEL_ID, private=True, exist_ok=True)
    except Exception as err:
        print(f"Note: {err}")

    # Push every file in the quantized model directory.
    print("Uploading files...")
    upload_kwargs = {
        "folder_path": quantized_model_path,
        "repo_id": OUTPUT_MODEL_ID,
        "repo_type": "model",
    }
    upload_folder(**upload_kwargs)

    print(f"✅ Step 3 complete: Uploaded to https://huggingface.co/{OUTPUT_MODEL_ID}")
def main():
    """Run the whole pipeline: merge, quantize, upload, then optional cleanup.

    Prints a summary banner and the detected GPU (or a warning when none is
    available), runs steps 1-3 in order, and finally asks whether the
    intermediate merged model directory should be deleted. When no
    interactive input is available the intermediate files are kept.
    """
    print("\n" + "=" * 70)
    print(" LoRA Merge + AWQ Quantization Pipeline")
    print("=" * 70)
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Base: {BASE_MODEL}")
    print(f"LoRA: {LORA_MODEL}")
    print(f"Output: {OUTPUT_MODEL_ID}")
    print("=" * 70)

    # Report the GPU that will be used.
    if torch.cuda.is_available():
        print(f"\nGPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("WARNING: No GPU available, this will be slow!")

    # Step 1: merge the adapter.
    merged_path = step1_merge_lora()

    # Step 2: quantize the merged model.
    quantized_path = step2_quantize_awq(merged_path)

    # Step 3: upload the quantized model.
    step3_upload_to_hub(quantized_path)

    # Optional cleanup of the intermediate merged model.
    print("\n" + "=" * 60)
    print("Cleanup")
    print("=" * 60)
    try:
        cleanup = input("Delete intermediate files? (merged_model/) [y/N]: ").strip().lower()
    except EOFError:
        # Non-interactive run (nohup, batch job, CI): stdin is closed, so
        # fall back to the prompt's default answer instead of crashing after
        # the pipeline has already succeeded.
        print("\nNo interactive input available; keeping intermediate files.")
        cleanup = ""
    if cleanup == 'y':
        shutil.rmtree(MERGED_MODEL_DIR, ignore_errors=True)
        print("Cleaned up merged_model/")

    print("\n" + "=" * 70)
    print("🎉 Pipeline complete!")
    print(f"Model available at: https://huggingface.co/{OUTPUT_MODEL_ID}")
    print("=" * 70)


if __name__ == "__main__":
    main()