#!/usr/bin/env python3
"""
Merge LoRA weights into Qwen3-0.6B and export merged model for GGUF conversion.
1. Load base Qwen3-0.6B
2. Apply LoRA adapters
3. Load trained LoRA weights from checkpoint
4. Merge LoRA into base weights (W' = W + B*A*scaling)
5. Save merged model in HuggingFace format
6. Convert to GGUF using llama.cpp's converter
Usage:
python3 merge_and_export.py --checkpoint /workspace/output/best_distill.pt --output-dir /workspace/merged
"""
import argparse
import os
import re
import sys
import time
from pathlib import Path

sys.stdout.reconfigure(line_buffering=True)


def log(msg):
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="Path to best_distill.pt")
    parser.add_argument("--output-dir", default="/workspace/merged")
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--gguf-output", default="/workspace/merged/qwen3-0.6b-summarizer.gguf")
    args = parser.parse_args()
    # Auto-install lightweight deps; torch itself is assumed preinstalled.
    import subprocess as _sp
    for pkg in ["numpy", "transformers", "accelerate", "safetensors"]:
        try:
            __import__(pkg)
        except ImportError:
            log(f"Installing {pkg}...")
            _sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", pkg], check=True)
    import torch
    import torch.nn as nn
    from transformers import AutoTokenizer, AutoModelForCausalLM

    log(f"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
    # The merge itself runs on CPU (both the model and the checkpoint are
    # loaded there), so no device placement is needed.
    os.makedirs(args.output_dir, exist_ok=True)
    # ── Load checkpoint ────────────────────────────────────────────────
    log(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint stores a config dict alongside tensors
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    config = ckpt.get("config", {})
    lora_rank = config.get("lora_rank", 16)
    lora_alpha = config.get("lora_alpha", 32)
    scaling = lora_alpha / lora_rank
    log(f"LoRA rank={lora_rank} alpha={lora_alpha} scaling={scaling}")
    # ── Load base model ────────────────────────────────────────────────
    log(f"Loading base model: {args.model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=torch.float32, trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    log("Model loaded")
    # ── Merge LoRA weights ─────────────────────────────────────────────
    log("Merging LoRA weights into base model...")
    lora_state = ckpt["lora_state"]
    n_merged = 0
    for name, module in model.named_modules():
        for proj_name in ["q_proj", "v_proj"]:
            if not hasattr(module, proj_name):
                continue
            proj = getattr(module, proj_name)
            if not isinstance(proj, nn.Linear):
                continue
            # Find matching LoRA weights.
            # Key format from training: "model.layers.N.self_attn.q_proj.lora_A"
            full_path = f"{name}.{proj_name}"
            lora_key_a = None
            for k in lora_state:
                if proj_name in k and "lora_A" in k:
                    # Match by layer path
                    lora_path = k.replace(".lora_A", "").replace(".lora_B", "")
                    if full_path in lora_path or lora_path in full_path:
                        lora_key_a = k
                        break
            if lora_key_a is None:
                # Try simpler substring matching
                for k in lora_state:
                    if full_path in k and "lora_A" in k:
                        lora_key_a = k
                        break
            if lora_key_a is None:
                continue
            lora_key_b = lora_key_a.replace("lora_A", "lora_B")
            if lora_key_b not in lora_state:
                continue
            A_weight = lora_state[lora_key_a]["weight"].float()  # (rank, in_features)
            B_weight = lora_state[lora_key_b]["weight"].float()  # (out_features, rank)
            # Merge: W' = W + B @ A * scaling
            delta = (B_weight @ A_weight) * scaling
            proj.weight.data += delta.to(proj.weight.dtype)
            n_merged += 1
    log(f"Merged {n_merged} LoRA layers into base weights")
    if n_merged == 0:
        log("WARNING: No LoRA layers merged! Trying alternative key matching...")
        log(f"Available LoRA keys: {list(lora_state.keys())[:10]}")
        # Try matching by index
        lora_pairs = {}
        for k, v in lora_state.items():
            base_key = k.replace(".lora_A", "").replace(".lora_B", "")
            if base_key not in lora_pairs:
                lora_pairs[base_key] = {}
            if "lora_A" in k:
                lora_pairs[base_key]["A"] = v
            elif "lora_B" in k:
                lora_pairs[base_key]["B"] = v
        # Collect all q_proj and v_proj layers in order
        target_layers = []
        for name, module in model.named_modules():
            for proj_name in ["q_proj", "v_proj"]:
                if hasattr(module, proj_name):
                    proj = getattr(module, proj_name)
                    if isinstance(proj, nn.Linear):
                        target_layers.append((name, proj_name, proj))
        # Sort LoRA pairs and match by index. Use a natural sort so that
        # "layers.10" does not sort before "layers.2"; a plain lexicographic
        # sort would break the pairing with the named_modules order.
        def natural_key(item):
            return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", item[0])]
        sorted_pairs = sorted(lora_pairs.items(), key=natural_key)
        log(f"Found {len(sorted_pairs)} LoRA pairs, {len(target_layers)} target layers")
        if len(sorted_pairs) != len(target_layers):
            log("WARNING: pair/layer counts differ; zip() will merge only the overlap")
        for (lora_key, pair), (name, proj_name, proj) in zip(sorted_pairs, target_layers):
            if "A" in pair and "B" in pair:
                A_weight = pair["A"]["weight"].float()
                B_weight = pair["B"]["weight"].float()
                delta = (B_weight @ A_weight) * scaling
                proj.weight.data += delta.to(proj.weight.dtype)
                n_merged += 1
        log(f"Merged {n_merged} LoRA layers (index matching)")
    # ── Save merged model ──────────────────────────────────────────────
    log(f"Saving merged model to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    total_mb = sum(f.stat().st_size for f in Path(args.output_dir).rglob("*") if f.is_file()) / 1024**2
    log(f"Merged model saved ({total_mb:.0f} MB)")
    # ── Convert to GGUF ────────────────────────────────────────────────
    log("Converting to GGUF (Q8_0)...")
    try:
        # Install llama.cpp's gguf writer package
        _sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", "gguf"], check=True)
        # Try a transformers gguf-export subcommand first; the installed
        # transformers version may not ship it, in which case the script
        # falls back to llama.cpp's converter below.
        result = _sp.run([
            sys.executable, "-m", "transformers", "gguf-export",
            "--model", args.output_dir,
            "--output", args.gguf_output,
            "--quantize", "q8_0",
        ], capture_output=True, text=True, timeout=300)
        if result.returncode != 0:
            log(f"transformers gguf-export failed: {result.stderr[:200]}")
            # Fallback: use llama.cpp's convert script
            log("Trying llama.cpp converter...")
            _sp.run(["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git",
                     "/tmp/llama.cpp"], capture_output=True, timeout=120)
            _sp.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q",
                     "-r", "/tmp/llama.cpp/requirements.txt"], capture_output=True, timeout=120)
            # Convert HF -> GGUF F16 first
            gguf_f16 = args.gguf_output.replace(".gguf", "-f16.gguf")
            result = _sp.run([
                sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
                args.output_dir,
                "--outfile", gguf_f16,
                "--outtype", "f16",
            ], capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                log(f"GGUF F16 created: {gguf_f16}")
                # Quantize to Q8_0 (requires a compiled llama-quantize binary);
                # if the binary is missing, keep the F16 file rather than
                # silently claiming Q8_0 succeeded.
                quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
                if os.path.exists(quantize_bin):
                    q8_result = _sp.run([
                        quantize_bin, gguf_f16, args.gguf_output, "q8_0",
                    ], capture_output=True, text=True, timeout=300)
                    if q8_result.returncode == 0:
                        log(f"GGUF Q8_0 created: {args.gguf_output}")
                    else:
                        log(f"Quantization failed, using F16: {gguf_f16}")
                        args.gguf_output = gguf_f16
                else:
                    log(f"llama-quantize not found, using F16: {gguf_f16}")
                    args.gguf_output = gguf_f16
            else:
                log(f"GGUF conversion failed: {result.stderr[:300]}")
        else:
            log(f"GGUF created: {args.gguf_output}")
    except Exception as e:
        log(f"GGUF conversion error: {e}")
    # List outputs
    log("")
    log("Output files:")
    for f in sorted(os.listdir(args.output_dir)):
        path = os.path.join(args.output_dir, f)
        if os.path.isfile(path):
            size = os.path.getsize(path)
            log(f"  {f}: {size/1024**2:.1f} MB")
    log("")
    log("DONE")

if __name__ == "__main__":
    main()
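
# Usage sketch for the exported GGUF (assumes a compiled llama.cpp checkout,
# e.g. the /tmp/llama.cpp clone above; this script does not run inference):
#   /tmp/llama.cpp/build/bin/llama-cli \
#       -m /workspace/merged/qwen3-0.6b-summarizer.gguf \
#       -p "Summarize the following article: ..."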