ericflo commited on
Commit
ad0946a
Β·
verified Β·
1 Parent(s): 834460a

Upload training/merge_and_export.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/merge_and_export.py +228 -0
training/merge_and_export.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Merge LoRA weights into Qwen3-0.6B and export merged model for GGUF conversion.
4
+
5
+ 1. Load base Qwen3-0.6B
6
+ 2. Apply LoRA adapters
7
+ 3. Load trained LoRA weights from checkpoint
8
+ 4. Merge LoRA into base weights (W' = W + B*A*scaling)
9
+ 5. Save merged model in HuggingFace format
10
+ 6. Convert to GGUF using llama.cpp's converter
11
+
12
+ Usage:
13
+ python3 merge_and_export.py --checkpoint /workspace/output/best_distill.pt --output-dir /workspace/merged
14
+ """
15
+ import argparse
16
+ import json
17
+ import math
18
+ import os
19
+ import sys
20
+ import time
21
+
22
+ sys.stdout.reconfigure(line_buffering=True)
23
+
24
+
25
def log(msg):
    """Print *msg* to stdout prefixed with a [HH:MM:SS] timestamp, flushed immediately."""
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
27
+
28
+
29
def _ensure_packages(sp, packages):
    """Best-effort install of any of *packages* that cannot currently be imported."""
    for pkg in packages:
        try:
            __import__(pkg)
        except ImportError:
            log(f"Installing {pkg}...")
            sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", pkg], check=True)


def _merge_lora_by_name(model, lora_state, scaling):
    """Merge LoRA A/B pairs into matching q_proj/v_proj Linear layers by module path.

    Keys are expected like "model.layers.N.self_attn.q_proj.lora_A" — TODO confirm
    against the training script that produced the checkpoint.

    Returns the number of projection layers merged.
    """
    import torch.nn as nn

    n_merged = 0
    for name, module in model.named_modules():
        for proj_name in ("q_proj", "v_proj"):
            proj = getattr(module, proj_name, None)
            if not isinstance(proj, nn.Linear):
                continue

            full_path = f"{name}.{proj_name}"
            lora_key_a = None
            # Primary match: strip the lora_A/lora_B suffix and compare module paths.
            for k in lora_state:
                if proj_name in k and "lora_A" in k:
                    lora_path = k.replace(".lora_A", "").replace(".lora_B", "")
                    if full_path in lora_path or lora_path in full_path:
                        lora_key_a = k
                        break
            if lora_key_a is None:
                # Secondary match: raw substring containment on the full key.
                for k in lora_state:
                    if full_path in k and "lora_A" in k:
                        lora_key_a = k
                        break
            if lora_key_a is None:
                continue

            lora_key_b = lora_key_a.replace("lora_A", "lora_B")
            if lora_key_b not in lora_state:
                continue

            a_weight = lora_state[lora_key_a]["weight"].float()  # (rank, in_features)
            b_weight = lora_state[lora_key_b]["weight"].float()  # (out_features, rank)
            # Merge: W' = W + B @ A * scaling
            delta = (b_weight @ a_weight) * scaling
            proj.weight.data += delta.to(proj.weight.dtype)
            n_merged += 1
    return n_merged


def _merge_lora_by_index(model, lora_state, scaling):
    """Fallback merge: pair sorted LoRA A/B entries with q/v projections positionally.

    Only correct if the checkpoint stored pairs in the same order the model
    enumerates its q_proj/v_proj layers — NOTE(review): verify against trainer.
    Returns the number of layers merged.
    """
    import torch.nn as nn

    # Group lora_A / lora_B tensors under their common base key.
    pairs = {}
    for k, v in lora_state.items():
        base_key = k.replace(".lora_A", "").replace(".lora_B", "")
        slot = pairs.setdefault(base_key, {})
        if "lora_A" in k:
            slot["A"] = v
        elif "lora_B" in k:
            slot["B"] = v

    # Collect all q_proj / v_proj Linear layers in model traversal order.
    targets = []
    for _name, module in model.named_modules():
        for proj_name in ("q_proj", "v_proj"):
            proj = getattr(module, proj_name, None)
            if isinstance(proj, nn.Linear):
                targets.append(proj)

    sorted_pairs = sorted(pairs.items())
    log(f"Found {len(sorted_pairs)} LoRA pairs, {len(targets)} target layers")

    n_merged = 0
    for (_key, pair), proj in zip(sorted_pairs, targets):
        if "A" in pair and "B" in pair:
            delta = (pair["B"]["weight"].float() @ pair["A"]["weight"].float()) * scaling
            proj.weight.data += delta.to(proj.weight.dtype)
            n_merged += 1
    return n_merged


def _export_gguf(sp, model_dir, gguf_output):
    """Best-effort conversion of the merged HF model to GGUF (Q8_0, F16 fallback).

    Returns the path of the GGUF file actually produced, or None on failure.
    """
    sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", "gguf"], check=True)

    # Attempt 1: transformers' own exporter (may not exist in installed version).
    result = sp.run([
        sys.executable, "-m", "transformers", "gguf-export",
        "--model", model_dir,
        "--output", gguf_output,
        "--quantize", "q8_0",
    ], capture_output=True, text=True, timeout=300)
    if result.returncode == 0:
        log(f"GGUF created: {gguf_output}")
        return gguf_output

    log(f"transformers gguf-export failed: {result.stderr[:200]}")
    # Attempt 2: llama.cpp's converter script.
    log("Trying llama.cpp converter...")
    sp.run(["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git",
            "/tmp/llama.cpp"], capture_output=True, timeout=120)
    sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q",
            "-r", "/tmp/llama.cpp/requirements.txt"], capture_output=True, timeout=120)

    # Convert HF -> GGUF F16 first, then quantize to Q8_0 if the binary exists.
    gguf_f16 = gguf_output.replace(".gguf", "-f16.gguf")
    result = sp.run([
        sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
        model_dir,
        "--outfile", gguf_f16,
        "--outtype", "f16",
    ], capture_output=True, text=True, timeout=300)
    if result.returncode != 0:
        log(f"GGUF conversion failed: {result.stderr[:300]}")
        return None
    log(f"GGUF F16 created: {gguf_f16}")

    quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
    if not os.path.exists(quantize_bin):
        # BUG FIX: the original substituted `echo` when llama-quantize was
        # missing; `echo` exits 0, so the script falsely reported a Q8_0 file
        # that was never created. Fall back to the F16 file explicitly instead.
        log(f"llama-quantize not found, using F16: {gguf_f16}")
        return gguf_f16

    q8_result = sp.run([quantize_bin, gguf_f16, gguf_output, "q8_0"],
                       capture_output=True, text=True, timeout=300)
    if q8_result.returncode == 0:
        log(f"GGUF Q8_0 created: {gguf_output}")
        return gguf_output
    log(f"Quantization failed, using F16: {gguf_f16}")
    return gguf_f16


def main():
    """Merge trained LoRA weights into the base model, save it in HF format,
    and convert the result to GGUF.

    Command-line arguments:
        --checkpoint   path to the trained checkpoint (required)
        --output-dir   directory for the merged HF model
        --model-name   base model to load from the Hub
        --gguf-output  destination path for the GGUF file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="Path to best_distill.pt")
    parser.add_argument("--output-dir", default="/workspace/merged")
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--gguf-output", default="/workspace/merged/qwen3-0.6b-summarizer.gguf")
    args = parser.parse_args()

    # Auto-install runtime deps before the heavyweight imports below.
    import subprocess as _sp
    _ensure_packages(_sp, ["numpy", "transformers", "accelerate", "safetensors"])

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    log(f"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
    os.makedirs(args.output_dir, exist_ok=True)

    # ── Load checkpoint ────────────────────────────────────────────────
    log(f"Loading checkpoint: {args.checkpoint}")
    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints produced by our own training run.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    config = ckpt.get("config", {})
    lora_rank = config.get("lora_rank", 16)
    lora_alpha = config.get("lora_alpha", 32)
    scaling = lora_alpha / lora_rank
    log(f"LoRA rank={lora_rank} alpha={lora_alpha} scaling={scaling}")

    # ── Load base model ────────────────────────────────────────────────
    log(f"Loading base model: {args.model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=torch.float32, trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    log("Model loaded")

    # ── Merge LoRA weights ─────────────────────────────────────────────
    log("Merging LoRA weights into base model...")
    lora_state = ckpt["lora_state"]
    n_merged = _merge_lora_by_name(model, lora_state, scaling)
    log(f"Merged {n_merged} LoRA layers into base weights")

    if n_merged == 0:
        log("WARNING: No LoRA layers merged! Trying alternative key matching...")
        log(f"Available LoRA keys: {list(lora_state.keys())[:10]}")
        n_merged = _merge_lora_by_index(model, lora_state, scaling)
        log(f"Merged {n_merged} LoRA layers (index matching)")

    # ── Save merged model ──────────────────────────────────────────────
    log(f"Saving merged model to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    from pathlib import Path
    total_mb = sum(
        f.stat().st_size for f in Path(args.output_dir).rglob("*") if f.is_file()
    ) / 1024**2
    log(f"Merged model saved ({total_mb:.0f} MB)")

    # ── Convert to GGUF ────────────────────────────────────────────────
    log("Converting to GGUF (Q8_0)...")
    try:
        produced = _export_gguf(_sp, args.output_dir, args.gguf_output)
        if produced is not None:
            args.gguf_output = produced
    except Exception as e:
        # Best-effort: a failed GGUF export must not discard the merged HF model.
        log(f"GGUF conversion error: {e}")

    # List outputs
    log("")
    log("Output files:")
    for f in sorted(os.listdir(args.output_dir)):
        path = os.path.join(args.output_dir, f)
        if os.path.isfile(path):
            size = os.path.getsize(path)
            log(f"  {f}: {size/1024**2:.1f} MB")

    log("")
    log("DONE")
226
+
227
# Entry point guard: run the merge/export pipeline only when executed as a
# script, not when this module is imported.
if __name__ == "__main__":
    main()