blitzkode / scripts /export_production.py
neuralbroker's picture
Add scripts/export_production.py
b56b94e verified
raw
history blame
7.58 kB
#!/usr/bin/env python3
"""Merge LoRA adapter → HuggingFace model → GGUF (blitzkode.gguf).

Pipeline
--------
1. Load base model + LoRA adapter, merge and unload adapters.
2. Save the merged HuggingFace model to --merged-dir (default: exported/merged/).
3. Download llama.cpp convert_hf_to_gguf.py from GitHub if not present.
4. Run GGUF conversion (quantised Q4_K_M by default).
5. Verify the output GGUF is loadable with llama-cpp-python.
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import urllib.request
from pathlib import Path
# Repository root: the directory that contains this script's parent folder.
REPO_ROOT = Path(__file__).resolve().parent.parent

# Default locations for the LoRA checkpoint and both export artefacts.
DEFAULT_CHECKPOINT = REPO_ROOT.joinpath("checkpoints", "blitzkode-1.5b-lora", "final")
DEFAULT_MERGED_DIR = REPO_ROOT.joinpath("exported", "merged")
DEFAULT_GGUF_OUT = REPO_ROOT.joinpath("blitzkode.gguf")

# Where the llama.cpp conversion script is cached, and where it is fetched from.
LLAMA_CPP_SCRIPTS_DIR = REPO_ROOT.joinpath("llama.cpp")
CONVERT_SCRIPT_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/convert_hf_to_gguf.py"
def parse_args() -> argparse.Namespace:
    """Parse the command line for the export pipeline.

    Returns:
        Namespace with ``checkpoint``, ``merged_dir`` and ``gguf_out`` (all
        ``Path``), ``quant_type`` (str), and the booleans ``skip_merge``,
        ``skip_gguf`` and ``verify``. ``verify`` starts out True and is
        cleared only by ``--no-verify``.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Filesystem locations; each falls back to a repo-relative default.
    for flag, fallback in (
        ("--checkpoint", DEFAULT_CHECKPOINT),
        ("--merged-dir", DEFAULT_MERGED_DIR),
        ("--gguf-out", DEFAULT_GGUF_OUT),
    ):
        parser.add_argument(flag, type=Path, default=fallback)

    parser.add_argument(
        "--quant-type",
        default="q4_k_m",
        help="GGUF quantisation type (q4_k_m, q8_0, f16, …)",
    )
    parser.add_argument(
        "--skip-merge",
        action="store_true",
        help="Skip merge step; use --merged-dir as-is.",
    )
    parser.add_argument(
        "--skip-gguf",
        action="store_true",
        help="Skip GGUF conversion (only merge).",
    )

    # "--verify" cannot change anything (its default is already True); it only
    # lets callers state the intent explicitly. "--no-verify" shares its dest.
    parser.add_argument(
        "--verify",
        action="store_true",
        default=True,
        help="Verify GGUF is loadable (default: on).",
    )
    parser.add_argument("--no-verify", action="store_false", dest="verify")

    namespace = parser.parse_args()
    return namespace
# ─── Step 1: Merge ────────────────────────────────────────────────────────────
def merge_adapter(checkpoint: Path, merged_dir: Path) -> None:
    """Merge the LoRA adapter in *checkpoint* into its base model and save it.

    The base model id is read from ``base_model_name_or_path`` in the
    adapter's ``adapter_config.json``. It is loaded in float16 on the CPU, the
    adapter is merged and unloaded, and the merged weights plus tokenizer are
    written to *merged_dir* (created if missing).

    Raises:
        SystemExit: if ``adapter_config.json`` is missing or names no base model.
    """
    print("\n[1/3] Merging LoRA adapter into base model …")

    # Validate the adapter config before the heavy torch/peft/transformers
    # imports so that a wrong --checkpoint fails immediately.
    config_path = checkpoint / "adapter_config.json"
    if not config_path.exists():
        raise SystemExit(f"adapter_config.json not found: {config_path}")
    with config_path.open(encoding="utf-8") as fh:
        adapter_config = json.load(fh)
    base_name = adapter_config.get("base_model_name_or_path")
    if not base_name:
        raise SystemExit(f"base_model_name_or_path missing from {config_path}")

    import torch  # noqa: PLC0415
    from peft import PeftModel  # noqa: PLC0415
    from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: PLC0415

    print(f" Checkpoint : {checkpoint}")
    print(f" Base model : {base_name}")

    # Prefer the tokenizer saved next to the adapter; fall back to the base
    # model's when the checkpoint has none, instead of crashing.
    if (checkpoint / "tokenizer_config.json").exists():
        tokenizer_source = str(checkpoint)
    else:
        print(" No tokenizer in checkpoint; using the base model's tokenizer.")
        tokenizer_source = base_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)

    print(" Loading base model …")
    base = AutoModelForCausalLM.from_pretrained(
        base_name,
        dtype=torch.float16,
        device_map="cpu",
        trust_remote_code=True,
    )

    print(" Loading & merging adapter …")
    model = PeftModel.from_pretrained(base, str(checkpoint))
    # Fold the LoRA deltas into the base weights and drop the PEFT wrappers so
    # the saved model is a plain HuggingFace checkpoint.
    model = model.merge_and_unload()

    merged_dir.mkdir(parents=True, exist_ok=True)
    print(f" Saving merged model to {merged_dir} …")
    model.save_pretrained(str(merged_dir))
    tokenizer.save_pretrained(str(merged_dir))
    print(" Merge complete.")
# ─── Step 2: Download convert script ─────────────────────────────────────────
def ensure_convert_script(scripts_dir: "Path | None" = None) -> Path:
    """Return the path to llama.cpp's ``convert_hf_to_gguf.py``, downloading it if absent.

    Args:
        scripts_dir: Cache directory for the script. Defaults to
            ``LLAMA_CPP_SCRIPTS_DIR``. It is created if missing.

    Returns:
        Path to the cached ``convert_hf_to_gguf.py``.

    Raises:
        SystemExit: if the download fails or returns an empty body.
    """
    import http.client  # noqa: PLC0415

    target_dir = LLAMA_CPP_SCRIPTS_DIR if scripts_dir is None else scripts_dir
    target_dir.mkdir(parents=True, exist_ok=True)
    convert_script = target_dir / "convert_hf_to_gguf.py"

    if convert_script.exists():
        print(f"\n Using cached: {convert_script}")
        return convert_script

    print("\n Downloading convert_hf_to_gguf.py from llama.cpp GitHub …")
    try:
        req = urllib.request.Request(CONVERT_SCRIPT_URL, headers={"User-Agent": "BlitzKode/2.0"})
        with urllib.request.urlopen(req, timeout=60) as response:
            content = response.read()
    except (OSError, http.client.HTTPException) as exc:
        # URLError/HTTPError/timeouts are OSError subclasses; IncompleteRead is
        # an HTTPException.
        raise SystemExit(f"Failed to download convert_hf_to_gguf.py: {exc}") from exc
    if not content:
        raise SystemExit("Failed to download convert_hf_to_gguf.py: empty response")

    # Write to a sibling temp file and rename it into place. An interrupted
    # write then never leaves a truncated script that later runs would treat
    # as a valid cache hit.
    partial = convert_script.with_name(convert_script.name + ".part")
    partial.write_bytes(content)
    partial.replace(convert_script)
    print(f" Saved to: {convert_script}")
    return convert_script
# ─── Step 3: GGUF conversion ──────────────────────────────────────────────────
# Output types accepted by llama.cpp's convert_hf_to_gguf.py for --outtype.
# K-quants such as q4_k_m (this script's default) are not among them. They
# are produced by a second pass through llama.cpp's llama-quantize tool.
_CONVERT_OUTTYPES = frozenset({"f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"})


def _find_llama_quantize() -> "Path | None":
    """Return a llama-quantize executable from PATH or a local llama.cpp build, if any."""
    import shutil  # noqa: PLC0415

    on_path = shutil.which("llama-quantize")
    if on_path:
        return Path(on_path)
    build_bin = LLAMA_CPP_SCRIPTS_DIR / "build" / "bin"
    for candidate in (build_bin / "llama-quantize", build_bin / "llama-quantize.exe"):
        if candidate.is_file():
            return candidate
    return None


def _run_step(cmd: list[str], what: str) -> None:
    """Echo and run *cmd*; raise SystemExit naming *what* if it exits non-zero."""
    print(f" Running: {' '.join(cmd)}")
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        raise SystemExit(f"{what} failed (exit code {result.returncode}).")


def convert_to_gguf(merged_dir: Path, gguf_out: Path, quant_type: str) -> None:
    """Convert the merged HF model in *merged_dir* to a GGUF file at *gguf_out*.

    Types that convert_hf_to_gguf.py emits directly (``_CONVERT_OUTTYPES``)
    are written in one pass. Any other type, such as the default ``q4_k_m``,
    takes two steps. The model is first converted to an f16 intermediate next
    to *gguf_out*, which is then quantised with llama.cpp's ``llama-quantize``.
    The intermediate is removed only after quantisation succeeds.

    Raises:
        SystemExit: if the type needs llama-quantize and it cannot be found,
            or if any conversion subprocess exits non-zero.
    """
    print(f"\n[2/3] Converting to GGUF ({quant_type}) …")
    outtype = quant_type.lower()

    # Resolve llama-quantize before the slow conversion so a missing tool
    # fails fast, not after writing a multi-GB intermediate file.
    quantize_bin = None
    if outtype not in _CONVERT_OUTTYPES:
        quantize_bin = _find_llama_quantize()
        if quantize_bin is None:
            raise SystemExit(
                f"--quant-type {quant_type!r} is not supported by convert_hf_to_gguf.py "
                f"(supported: {', '.join(sorted(_CONVERT_OUTTYPES))}) and llama-quantize "
                "was not found on PATH or in llama.cpp/build/bin. Build llama.cpp or "
                "choose a supported type."
            )

    convert_script = ensure_convert_script()
    gguf_out.parent.mkdir(parents=True, exist_ok=True)
    convert_cmd = [sys.executable, str(convert_script), str(merged_dir), "--outfile"]

    if quantize_bin is None:
        _run_step([*convert_cmd, str(gguf_out), "--outtype", outtype], "GGUF conversion")
    else:
        intermediate = gguf_out.with_name(f"{gguf_out.stem}.f16.gguf")
        _run_step([*convert_cmd, str(intermediate), "--outtype", "f16"], "GGUF conversion")
        _run_step(
            [str(quantize_bin), str(intermediate), str(gguf_out), outtype.upper()],
            "GGUF quantisation",
        )
        # Kept on failure so the quantise step can be retried by hand.
        intermediate.unlink()

    print(f" GGUF written: {gguf_out} ({gguf_out.stat().st_size / 1024**3:.2f} GB)")
# ─── Step 4: Verify ───────────────────────────────────────────────────────────
def verify_gguf(gguf_path: Path) -> None:
    """Smoke-test *gguf_path* by loading it with llama-cpp-python and sampling a few tokens.

    This is best effort. Any failure, including llama-cpp-python not being
    installed, is printed as a warning and never raised.
    """
    print("\n[3/3] Verifying GGUF with llama-cpp-python …")
    chat_prompt = "<|im_start|>user\nSay hello.<|im_end|>\n<|im_start|>assistant\n"
    load_kwargs = {
        "model_path": str(gguf_path),
        "n_ctx": 128,
        "n_threads": 2,
        "n_gpu_layers": 0,
        "verbose": False,
    }
    try:
        import llama_cpp  # noqa: PLC0415

        model = llama_cpp.Llama(**load_kwargs)
        completion = model(chat_prompt, max_tokens=8, stop=["<|im_end|>"])
        first_choice = completion["choices"][0]
        sample = first_choice["text"].strip()
        print(f" Sample output: {sample!r}")
        print(" Verification PASSED.")
    except Exception as err:
        print(f" [WARN] Verification raised: {err}")
        print(" GGUF was written; manual verification recommended.")
# ─── Main ─────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run the export: merge, convert to GGUF, verify, then print a summary.

    Stages are controlled by the CLI flags from ``parse_args``. When the merge
    is skipped, ``--merged-dir`` must already exist.
    """
    args = parse_args()
    banner = "=" * 72
    print(banner, "BLITZKODE PRODUCTION EXPORT", banner, sep="\n")

    if args.skip_merge:
        print("\n[1/3] Merge skipped (--skip-merge).")
        if not args.merged_dir.exists():
            raise SystemExit(f"Merged dir not found: {args.merged_dir}")
    else:
        merge_adapter(args.checkpoint, args.merged_dir)

    if args.skip_gguf:
        print("\n[2/3] GGUF conversion skipped (--skip-gguf).")
    else:
        convert_to_gguf(args.merged_dir, args.gguf_out, args.quant_type)

    # Verification also runs when conversion was skipped but a GGUF exists.
    if args.verify and args.gguf_out.exists():
        verify_gguf(args.gguf_out)

    print("\n" + banner)
    print("EXPORT COMPLETE")
    print(f" Merged HF model : {args.merged_dir}")
    if args.gguf_out.exists():
        print(f" GGUF model : {args.gguf_out}")
    print(
        "\nNext steps:",
        " python scripts/push_to_hub.py",
        " python server.py",
        sep="\n",
    )


if __name__ == "__main__":
    main()