"""
Merge LoRA weights into Qwen3-0.6B and export the merged model for GGUF conversion.

1. Load the trained LoRA checkpoint (rank/alpha config + A/B matrices)
2. Load the base Qwen3-0.6B model
3. Merge the LoRA deltas into the base weights (W' = W + B @ A * scaling)
4. Save the merged model in HuggingFace format
5. Convert to GGUF using llama.cpp's converter

Usage:
    python3 merge_and_export.py --checkpoint /workspace/output/best_distill.pt --output-dir /workspace/merged
"""
import argparse
import os
import re
import subprocess
import sys
import time

# Line-buffer stdout so log lines appear immediately even when piped.
sys.stdout.reconfigure(line_buffering=True)
|
|
|
|
def log(msg):
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)
|
|
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True, help="Path to best_distill.pt")
    parser.add_argument("--output-dir", default="/workspace/merged")
    parser.add_argument("--model-name", default="Qwen/Qwen3-0.6B")
    parser.add_argument("--gguf-output", default="/workspace/merged/qwen3-0.6b-summarizer.gguf")
    args = parser.parse_args()
|
|
    # Bootstrap Python dependencies before the heavy imports below.
    for pkg in ["numpy", "transformers", "accelerate", "safetensors"]:
        try:
            __import__(pkg)
        except ImportError:
            log(f"Installing {pkg}...")
            subprocess.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", pkg],
                           check=True)
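    # Note: torch itself is assumed to be preinstalled (it is imported below
    # but never pip-installed here), and --break-system-packages is required
    # on distros whose system Python is PEP 668 "externally managed".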
|
|
    import torch
    import torch.nn as nn
    from transformers import AutoTokenizer, AutoModelForCausalLM

    log(f"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}")
    # The merge itself runs on CPU; CUDA is reported for information only.
    os.makedirs(args.output_dir, exist_ok=True)
|
|
    # Load the training checkpoint and recover the LoRA hyperparameters.
    log(f"Loading checkpoint: {args.checkpoint}")
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    config = ckpt.get("config", {})
    lora_rank = config.get("lora_rank", 16)
    lora_alpha = config.get("lora_alpha", 32)
    scaling = lora_alpha / lora_rank
    log(f"LoRA rank={lora_rank} alpha={lora_alpha} scaling={scaling}")
|
|
    # Load the base model in float32 so the merge is numerically exact;
    # precision is reduced later, at GGUF quantization time.
    log(f"Loading base model: {args.model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype=torch.float32, trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    log("Model loaded")
|
|
    log("Merging LoRA weights into base model...")
    lora_state = ckpt["lora_state"]
    n_merged = 0

    for name, module in model.named_modules():
        for proj_name in ["q_proj", "v_proj"]:
            if not hasattr(module, proj_name):
                continue
            proj = getattr(module, proj_name)
            if not isinstance(proj, nn.Linear):
                continue

            # Find this projection's lora_A key. Checkpoint keys may carry a
            # wrapper prefix (e.g. from the training harness), so match by
            # mutual substring rather than exact equality.
            full_path = f"{name}.{proj_name}"
            lora_key_a = None
            for k in lora_state:
                if proj_name not in k or "lora_A" not in k:
                    continue
                lora_path = k.replace(".lora_A", "")
                if full_path in lora_path or lora_path in full_path:
                    lora_key_a = k
                    break

            if lora_key_a is None:
                continue

            lora_key_b = lora_key_a.replace("lora_A", "lora_B")
            if lora_key_b not in lora_state:
                continue

            A_weight = lora_state[lora_key_a]["weight"].float()
            B_weight = lora_state[lora_key_b]["weight"].float()

            # In-place merge: W' = W + scaling * (B @ A).
            delta = (B_weight @ A_weight) * scaling
            proj.weight.data += delta.to(proj.weight.dtype)
            n_merged += 1
|
|
| log(f"Merged {n_merged} LoRA layers into base weights") |
|
|
    if n_merged == 0:
        log("WARNING: No LoRA layers merged! Trying alternative key matching...")
        log(f"Available LoRA keys: {list(lora_state.keys())[:10]}")

        # Group checkpoint entries into (A, B) pairs keyed by module path.
        lora_pairs = {}
        for k, v in lora_state.items():
            base_key = k.replace(".lora_A", "").replace(".lora_B", "")
            if base_key not in lora_pairs:
                lora_pairs[base_key] = {}
            if "lora_A" in k:
                lora_pairs[base_key]["A"] = v
            elif "lora_B" in k:
                lora_pairs[base_key]["B"] = v

        # Collect target q_proj/v_proj layers in model definition order.
        target_layers = []
        for name, module in model.named_modules():
            for proj_name in ["q_proj", "v_proj"]:
                if hasattr(module, proj_name):
                    proj = getattr(module, proj_name)
                    if isinstance(proj, nn.Linear):
                        target_layers.append((name, proj_name, proj))

        # Sort pair keys so embedded layer indices compare numerically; a
        # plain lexicographic sort would put "layers.10" before "layers.2"
        # and misalign the pairs against the model's layer order.
        def natural_key(s):
            return [int(t) if t.isdigit() else t for t in re.split(r"(\d+)", s)]

        sorted_pairs = sorted(lora_pairs.items(), key=lambda kv: natural_key(kv[0]))
        log(f"Found {len(sorted_pairs)} LoRA pairs, {len(target_layers)} target layers")
        if len(sorted_pairs) != len(target_layers):
            log("WARNING: pair/layer counts differ; only the overlap will be merged")

        for (lora_key, pair), (name, proj_name, proj) in zip(sorted_pairs, target_layers):
            if "A" in pair and "B" in pair:
                A_weight = pair["A"]["weight"].float()
                B_weight = pair["B"]["weight"].float()
                delta = (B_weight @ A_weight) * scaling
                proj.weight.data += delta.to(proj.weight.dtype)
                n_merged += 1

        log(f"Merged {n_merged} LoRA layers (index matching)")
|
|
    log(f"Saving merged model to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    total_mb = sum(
        os.path.getsize(os.path.join(root, f))
        for root, _, files in os.walk(args.output_dir) for f in files
    ) / 1024**2
    log(f"Merged model saved ({total_mb:.0f} MB)")
|
|
    # Convert to GGUF with llama.cpp's convert_hf_to_gguf.py. The converter
    # reads the HuggingFace directory directly and can emit Q8_0 itself, so
    # no separate llama-quantize step is needed.
    log("Converting to GGUF (Q8_0)...")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q", "gguf"],
                       check=True)

        subprocess.run(["git", "clone", "--depth", "1",
                        "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
                       capture_output=True, timeout=120)
        subprocess.run([sys.executable, "-m", "pip", "install", "--break-system-packages", "-q",
                        "-r", "/tmp/llama.cpp/requirements.txt"],
                       capture_output=True, timeout=120)

        result = subprocess.run([
            sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
            args.output_dir,
            "--outfile", args.gguf_output,
            "--outtype", "q8_0",
        ], capture_output=True, text=True, timeout=300)

        if result.returncode == 0:
            log(f"GGUF Q8_0 created: {args.gguf_output}")
        else:
            log(f"Q8_0 conversion failed: {result.stderr[:300]}")
            # Fall back to an unquantized F16 export.
            gguf_f16 = args.gguf_output.replace(".gguf", "-f16.gguf")
            result = subprocess.run([
                sys.executable, "/tmp/llama.cpp/convert_hf_to_gguf.py",
                args.output_dir,
                "--outfile", gguf_f16,
                "--outtype", "f16",
            ], capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                log(f"GGUF F16 created: {gguf_f16}")
                args.gguf_output = gguf_f16
            else:
                log(f"GGUF conversion failed: {result.stderr[:300]}")

    except Exception as e:
        log(f"GGUF conversion error: {e}")
|
|
    log("")
    log("Output files:")
    for f in sorted(os.listdir(args.output_dir)):
        path = os.path.join(args.output_dir, f)
        if os.path.isfile(path):
            size = os.path.getsize(path)
            log(f"  {f}: {size / 1024**2:.1f} MB")

    log("")
    log("DONE")
|
|
|
|
if __name__ == "__main__":
    main()
|
|