#!/usr/bin/env python3
"""
LoRAマージ + AWQ量子化スクリプト

学習完了後に実行:
1. LoRAアダプターをベースモデルにマージ
2. AWQ量子化(4bit)
3. HuggingFaceにアップロード
"""

import shutil
from datetime import datetime

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from awq import AutoAWQForCausalLM

# ============================================================
# Configuration
# ============================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
LORA_MODEL = "hajimemat/qwen2.5-7b-glaive-fc-lora"  # trained LoRA adapter

# Output locations
MERGED_MODEL_DIR = "./merged_model"
QUANTIZED_MODEL_DIR = "./quantized_model"
OUTPUT_MODEL_ID = "hajimemat/qwen2.5-7b-glaive-fc-awq"

# AWQ quantization settings
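# Field notes (per AutoAWQ conventions, for reference): zero_point enables
# asymmetric quantization, q_group_size is the group size for the 4-bit
# weight quantization (128 is the common default), w_bit is the weight
# bit-width, and version selects the kernel ("GEMM" is general-purpose;
# "GEMV" is an alternative tuned for batch-size-1 decoding).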
AWQ_CONFIG = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM"
}


def step1_merge_lora():
    """Step 1: LoRAをベースモデルにマージ"""
    print("\n" + "=" * 60)
    print("Step 1: Merging LoRA adapter to base model")
    print("=" * 60)

    print(f"Base model: {BASE_MODEL}")
    print(f"LoRA model: {LORA_MODEL}")

    # Load the base model
    print("\nLoading base model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Load the tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    # Apply the LoRA adapter
    print("Loading LoRA adapter...")
    model = PeftModel.from_pretrained(base_model, LORA_MODEL)

    # Merge: fold the LoRA deltas into the base weights; merge_and_unload()
    # returns a plain transformers model
    print("Merging LoRA weights...")
    model = model.merge_and_unload()

    # Save
    print(f"Saving merged model to {MERGED_MODEL_DIR}...")
    model.save_pretrained(MERGED_MODEL_DIR, safe_serialization=True)
    tokenizer.save_pretrained(MERGED_MODEL_DIR)

    # Free memory
    del model
    del base_model
    torch.cuda.empty_cache()

    print("✅ Step 1 complete: LoRA merged")
    return MERGED_MODEL_DIR


def step2_quantize_awq(merged_model_path):
    """Step 2: AWQ量子化"""
    print("\n" + "=" * 60)
    print("Step 2: AWQ Quantization (4-bit)")
    print("=" * 60)

    print(f"Input model: {merged_model_path}")
    print(f"AWQ config: {AWQ_CONFIG}")

    # Load the model
    print("\nLoading merged model for quantization...")
    model = AutoAWQForCausalLM.from_pretrained(
        merged_model_path,
        trust_remote_code=True,
        safetensors=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(merged_model_path)

    # Calibration data for quantization
    # (simple samples are sufficient)
    calib_data = [
        "Hello, how are you today?",
        "What is the weather like?",
        "Can you help me with coding?",
        "Search for files containing docker",
        "List all repositories in the project",
        "Find the definition of the function main",
        "こんにちは、今日の天気はどうですか?",
        "プロジェクトのファイル構成を教えてください",
    ]
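    # (AutoAWQ also accepts a dataset name for calib_data, e.g. its default
    # "pileval"; a larger calibration set that matches the deployment
    # distribution is generally safer than a handful of prompts.)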

    # Run quantization
    print("\nQuantizing model (this may take a while)...")
    model.quantize(
        tokenizer,
        quant_config=AWQ_CONFIG,
        calib_data=calib_data,
    )

    # Save
    print(f"\nSaving quantized model to {QUANTIZED_MODEL_DIR}...")
    model.save_quantized(QUANTIZED_MODEL_DIR, safetensors=True)
    tokenizer.save_pretrained(QUANTIZED_MODEL_DIR)

    print("✅ Step 2 complete: AWQ quantization done")
    return QUANTIZED_MODEL_DIR
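

# Optional sanity check (a minimal sketch, not called by the pipeline below):
# reload the quantized model and run one short generation to confirm it loads.
# Assumes a single CUDA GPU; from_quantized() and fuse_layers are AutoAWQ APIs.
def smoke_test_quantized(quantized_model_path):
    """Reload the AWQ model and generate once as a smoke test."""
    model = AutoAWQForCausalLM.from_quantized(quantized_model_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
    inputs = tokenizer("Hello, how are you today?", return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # e.g. call smoke_test_quantized(QUANTIZED_MODEL_DIR) by hand after main()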


def step3_upload_to_hub(quantized_model_path):
    """Step 3: HuggingFaceにアップロード"""
    print("\n" + "=" * 60)
    print("Step 3: Upload to HuggingFace Hub")
    print("=" * 60)

    print(f"Uploading to: {OUTPUT_MODEL_ID}")

    from huggingface_hub import HfApi, upload_folder

    api = HfApi()

    # Create the repo (if it doesn't already exist)
    try:
        api.create_repo(OUTPUT_MODEL_ID, private=True, exist_ok=True)
    except Exception as e:
        print(f"Note: {e}")

    # Upload
    print("Uploading files...")
    upload_folder(
        folder_path=quantized_model_path,
        repo_id=OUTPUT_MODEL_ID,
        repo_type="model",
    )

    print(f"✅ Step 3 complete: Uploaded to https://huggingface.co/{OUTPUT_MODEL_ID}")


def main():
    print("\n" + "=" * 70)
    print("  LoRA Merge + AWQ Quantization Pipeline")
    print("=" * 70)
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Base: {BASE_MODEL}")
    print(f"LoRA: {LORA_MODEL}")
    print(f"Output: {OUTPUT_MODEL_ID}")
    print("=" * 70)

    # GPU check
    if torch.cuda.is_available():
        print(f"\nGPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("WARNING: No GPU available, this will be slow!")

    # Step 1: merge
    merged_path = step1_merge_lora()

    # Step 2: quantize
    quantized_path = step2_quantize_awq(merged_path)

    # Step 3: upload
    step3_upload_to_hub(quantized_path)

    # Cleanup (optional)
    print("\n" + "=" * 60)
    print("Cleanup")
    print("=" * 60)
    cleanup = input("Delete intermediate files? (merged_model/) [y/N]: ").strip().lower()
    if cleanup == 'y':
        shutil.rmtree(MERGED_MODEL_DIR, ignore_errors=True)
        print("Cleaned up merged_model/")

    print("\n" + "=" * 70)
    print("🎉 Pipeline complete!")
    print(f"Model available at: https://huggingface.co/{OUTPUT_MODEL_ID}")
    print("=" * 70)


if __name__ == "__main__":
    main()