#!/usr/bin/env python3
"""Merge a LoRA adapter into a full model and optionally push to Hugging Face."""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Optional, Tuple
import torch
from huggingface_hub import HfApi
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        Namespace exposing ``adapter_path``, ``output_dir``, ``repo_id``,
        ``push_to_hub``, ``private``, ``commit_message``,
        ``credentials_path``, ``max_shard_size``, ``trust_remote_code`` and
        ``bf16``.
    """
    # Ordered (flag, add_argument keyword arguments) pairs; the order here is
    # the order in which the options show up in ``--help``.
    option_specs = (
        (
            "--adapter-path",
            {
                "type": Path,
                "required": True,
                "help": "Directory containing adapter_model.safetensors + adapter_config.json.",
            },
        ),
        (
            "--output-dir",
            {
                "type": Path,
                "required": True,
                "help": "Directory where merged weights are saved.",
            },
        ),
        ("--repo-id", {"type": str, "default": None, "help": "Hub model repo id."}),
        ("--push-to-hub", {"action": "store_true", "help": "Upload merged model to Hub."}),
        ("--private", {"action": "store_true", "help": "Create private repo on Hub."}),
        (
            "--commit-message",
            {
                "type": str,
                "default": "Upload merged DeepSeek-Math conjecture model.",
            },
        ),
        (
            "--credentials-path",
            {
                "type": Path,
                "default": Path("huggingface-api-key.json"),
                "help": "Path to JSON credentials with {username, key}.",
            },
        ),
        (
            "--max-shard-size",
            {
                "type": str,
                "default": "5GB",
                "help": "Shard size passed to save_pretrained.",
            },
        ),
        (
            "--trust-remote-code",
            {
                "action": "store_true",
                "help": "Enable trust_remote_code for tokenizer/model loading.",
            },
        ),
        (
            "--bf16",
            {
                "action": "store_true",
                "help": "Load adapter in bfloat16 before merge (default float16).",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Merge a PEFT adapter into base weights and publish the merged model."
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def as_text(value: object) -> str:
    """Return *value* as a whitespace-trimmed string.

    ``None`` maps to the empty string; strings are trimmed as-is; any other
    object is converted with ``str()`` before trimming.
    """
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
def resolve_auth(credentials_path: Path) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Hub token and username from the environment or a JSON file.

    Environment variables win: the token comes from ``HF_TOKEN`` (falling
    back to ``HUGGINGFACE_HUB_TOKEN`` when that is unset or blank) and the
    username from ``HF_USERNAME``. Only values that are still missing are
    looked up in the credentials file, under its ``key`` and ``username``
    entries.

    Args:
        credentials_path: JSON file shaped like ``{"username": ..., "key": ...}``.
            It is not read when the environment already provides both values
            or when the path is not an existing regular file.

    Returns:
        ``(token, username)``; each element is ``None`` when unavailable or
        blank.

    Raises:
        json.JSONDecodeError: If the credentials file is consulted but does
            not contain valid JSON.
        ValueError: If the credentials file parses to something other than a
            JSON object.
    """
    # Trim each variable separately so a whitespace-only HF_TOKEN does not
    # shadow a usable HUGGINGFACE_HUB_TOKEN.
    token = (
        as_text(os.environ.get("HF_TOKEN"))
        or as_text(os.environ.get("HUGGINGFACE_HUB_TOKEN"))
        or None
    )
    username = as_text(os.environ.get("HF_USERNAME")) or None

    # Nothing left to look up: do not touch (or fail on) the credentials file.
    if token is not None and username is not None:
        return token, username
    # is_file() rather than exists(): a directory at this path is not a
    # credentials file and would otherwise blow up in read_text().
    if not credentials_path.is_file():
        return token, username

    data = json.loads(credentials_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(
            f"Credentials file must contain a JSON object: {credentials_path}"
        )

    if token is None:
        token = as_text(data.get("key")) or None
    if username is None:
        username = as_text(data.get("username")) or None
    return token, username
def merge_adapter(args: argparse.Namespace) -> None:
    """Merge the adapter at ``args.adapter_path`` and save it to ``args.output_dir``.

    The adapter is loaded through ``AutoPeftModelForCausalLM.from_pretrained``
    in bfloat16 when ``args.bf16`` is set (float16 otherwise) with
    ``device_map="auto"``, then ``merge_and_unload()`` is called on it. The
    tokenizer is loaded from the same adapter path. Both are written to the
    output directory, which is created if needed; weights are saved with
    ``safe_serialization=True`` and sharded by ``args.max_shard_size``.

    Args:
        args: Parsed CLI namespace; reads ``adapter_path``, ``output_dir``,
            ``bf16``, ``trust_remote_code`` and ``max_shard_size``.

    Raises:
        FileNotFoundError: If ``args.adapter_path`` does not exist.
    """
    adapter_dir = args.adapter_path
    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")

    adapter_src = str(adapter_dir)
    allow_remote_code = args.trust_remote_code
    if args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_src,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=allow_remote_code,
    )
    merged = peft_model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_src,
        trust_remote_code=allow_remote_code,
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    target_dir = str(args.output_dir)
    merged.save_pretrained(
        target_dir,
        safe_serialization=True,
        max_shard_size=args.max_shard_size,
    )
    tokenizer.save_pretrained(target_dir)
    print(f"Merged model saved to: {args.output_dir}")
def push_merged(args: argparse.Namespace, token: str, repo_id: str) -> None:
    """Upload the contents of ``args.output_dir`` to a Hub model repository.

    The repository is created first with ``exist_ok=True`` (private when
    ``args.private`` is set), then the whole output folder is uploaded using
    ``args.commit_message``.

    Args:
        args: Parsed CLI namespace; reads ``private``, ``output_dir`` and
            ``commit_message``.
        token: Hugging Face access token passed to ``HfApi``.
        repo_id: Target repository id.
    """
    hub = HfApi(token=token)
    # Both Hub calls address the same model repository.
    repo_target = {"repo_id": repo_id, "repo_type": "model"}

    hub.create_repo(private=args.private, exist_ok=True, **repo_target)
    hub.upload_folder(
        folder_path=str(args.output_dir),
        commit_message=args.commit_message,
        **repo_target,
    )
    print(f"Pushed merged model to https://huggingface.co/{repo_id}")
def _resolve_push_target(args: argparse.Namespace) -> Tuple[str, str]:
    """Return the ``(token, repo_id)`` required for a Hub upload.

    Raises:
        ValueError: If no token is available, or if ``--repo-id`` is absent
            and a default ``<username>/<output-dir-name>`` cannot be built.
    """
    token, username = resolve_auth(args.credentials_path)
    if token is None:
        raise ValueError("Missing HF token. Set HF_TOKEN or provide credentials JSON.")

    repo_id = as_text(args.repo_id)
    if repo_id:
        return token, repo_id

    if not username:
        raise ValueError("repo_id missing and username unavailable.")
    # Path(".").name is "", which would give the invalid id "<username>/".
    # Use the absolute path so relative dirs still yield a real folder name.
    folder_name = Path(os.path.abspath(args.output_dir)).name
    if not folder_name:
        raise ValueError(
            f"Cannot derive a repo name from output dir: {args.output_dir}"
        )
    return token, f"{username}/{folder_name}"


def main() -> None:
    """Run the CLI: merge the adapter, then push it when ``--push-to-hub`` is set.

    Push prerequisites (token and repo id) are resolved *before* the merge so
    that a missing credential is reported immediately rather than after the
    model has been loaded, merged and written to disk.

    Raises:
        ValueError: If pushing was requested but the token or repo id cannot
            be determined.
        FileNotFoundError: Propagated from ``merge_adapter`` when the adapter
            path does not exist.
    """
    args = parse_args()

    push_target = _resolve_push_target(args) if args.push_to_hub else None

    merge_adapter(args)
    if push_target is None:
        return

    token, repo_id = push_target
    push_merged(args, token=token, repo_id=repo_id)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|