math-conjecture-model / scripts /merge_and_push.py
NorthernTribe-Research's picture
Rename model repo target to math-conjecture-model and upload pipeline.
90dacf5 verified
#!/usr/bin/env python3
"""Merge a LoRA adapter into a full model and optionally push to Hugging Face."""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Optional, Tuple
import torch
from huggingface_hub import HfApi
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace with attributes ``adapter_path``, ``output_dir``,
        ``repo_id``, ``push_to_hub``, ``private``, ``commit_message``,
        ``credentials_path``, ``max_shard_size``, ``trust_remote_code`` and
        ``bf16``.
    """
    cli = argparse.ArgumentParser(
        description="Merge a PEFT adapter into base weights and publish the merged model."
    )

    # The interface is declared as data so it can be read as one table.
    # Registration order is preserved because it is the order used by --help.
    option_table = (
        (
            "--adapter-path",
            {
                "type": Path,
                "required": True,
                "help": "Directory containing adapter_model.safetensors + adapter_config.json.",
            },
        ),
        (
            "--output-dir",
            {
                "type": Path,
                "required": True,
                "help": "Directory where merged weights are saved.",
            },
        ),
        ("--repo-id", {"type": str, "default": None, "help": "Hub model repo id."}),
        ("--push-to-hub", {"action": "store_true", "help": "Upload merged model to Hub."}),
        ("--private", {"action": "store_true", "help": "Create private repo on Hub."}),
        (
            "--commit-message",
            {
                "type": str,
                "default": "Upload merged DeepSeek-Math conjecture model.",
            },
        ),
        (
            "--credentials-path",
            {
                "type": Path,
                "default": Path("huggingface-api-key.json"),
                "help": "Path to JSON credentials with {username, key}.",
            },
        ),
        (
            "--max-shard-size",
            {
                "type": str,
                "default": "5GB",
                "help": "Shard size passed to save_pretrained.",
            },
        ),
        (
            "--trust-remote-code",
            {
                "action": "store_true",
                "help": "Enable trust_remote_code for tokenizer/model loading.",
            },
        ),
        (
            "--bf16",
            {
                "action": "store_true",
                "help": "Load adapter in bfloat16 before merge (default float16).",
            },
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli.parse_args()
def as_text(value: object) -> str:
    """Return *value* as a string with surrounding whitespace removed.

    ``None`` maps to the empty string; strings are stripped as-is; any other
    object is converted with ``str()`` first.
    """
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
def resolve_auth(credentials_path: Path) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Hub token and username from the environment or a JSON file.

    Environment variables win: ``HF_TOKEN`` (then ``HUGGINGFACE_HUB_TOKEN``)
    for the token and ``HF_USERNAME`` for the username. Whatever is still
    missing is taken from the ``key`` / ``username`` fields of the JSON file
    at *credentials_path*, if that file exists.

    Args:
        credentials_path: JSON file expected to hold an object with
            ``username`` and ``key`` entries.

    Returns:
        ``(token, username)``. Either element is ``None`` when no non-blank
        value was found for it.

    Raises:
        ValueError: If the credentials file is read and is not valid JSON
            (raised by ``json.loads`` as ``json.JSONDecodeError``) or its top
            level is not a JSON object.
    """
    env = os.environ
    # Normalise each variable on its own so that a blank HF_TOKEN does not
    # hide a usable HUGGINGFACE_HUB_TOKEN.
    token = (
        as_text(env.get("HF_TOKEN"))
        or as_text(env.get("HUGGINGFACE_HUB_TOKEN"))
        or None
    )
    username = as_text(env.get("HF_USERNAME")) or None

    # Skip the file when the environment already answered everything, so a
    # stale or malformed credentials file cannot break an otherwise valid run.
    if token is not None and username is not None:
        return token, username
    # is_file() rather than exists(): a directory at this path is not usable.
    if not credentials_path.is_file():
        return token, username

    data = json.loads(credentials_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        # Report only the path; the file content may contain secrets.
        raise ValueError(
            f"Credentials file must contain a JSON object: {credentials_path}"
        )

    if token is None:
        token = as_text(data.get("key")) or None
    if username is None:
        username = as_text(data.get("username")) or None
    return token, username
def merge_adapter(args: argparse.Namespace) -> None:
    """Load the adapter, merge it, and save model plus tokenizer to disk.

    Reads ``adapter_path``, ``bf16``, ``trust_remote_code``, ``output_dir``
    and ``max_shard_size`` from *args*. The model is loaded through
    ``AutoPeftModelForCausalLM`` with ``device_map="auto"``, merged with
    ``merge_and_unload()``, and written with safetensors serialization. The
    tokenizer is loaded from the adapter directory and saved alongside it.

    Raises:
        FileNotFoundError: If ``args.adapter_path`` does not exist.
    """
    adapter_dir = args.adapter_path
    if not adapter_dir.exists():
        raise FileNotFoundError(f"Adapter path not found: {adapter_dir}")

    adapter_source = str(adapter_dir)
    allow_remote_code = args.trust_remote_code
    # float16 unless --bf16 was requested.
    load_dtype = torch.float16 if not args.bf16 else torch.bfloat16

    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_source,
        torch_dtype=load_dtype,
        device_map="auto",
        trust_remote_code=allow_remote_code,
    )
    merged_model = peft_model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(
        adapter_source,
        trust_remote_code=allow_remote_code,
    )

    destination = args.output_dir
    destination.mkdir(parents=True, exist_ok=True)
    destination_str = str(destination)

    merged_model.save_pretrained(
        destination_str,
        safe_serialization=True,
        max_shard_size=args.max_shard_size,
    )
    tokenizer.save_pretrained(destination_str)
    print(f"Merged model saved to: {destination}")
def push_merged(args: argparse.Namespace, token: str, repo_id: str) -> None:
    """Create the model repo if needed and upload the merged output folder.

    Args:
        args: Parsed CLI namespace; ``private``, ``output_dir`` and
            ``commit_message`` are read from it.
        token: Token handed to ``HfApi``.
        repo_id: Target model repository id on the Hub.
    """
    hub = HfApi(token=token)
    # Both calls address the same model repository.
    target = {"repo_id": repo_id, "repo_type": "model"}

    # exist_ok=True makes repo creation a no-op error-wise when it exists.
    hub.create_repo(private=args.private, exist_ok=True, **target)
    hub.upload_folder(
        folder_path=str(args.output_dir),
        commit_message=args.commit_message,
        **target,
    )
    print(f"Pushed merged model to https://huggingface.co/{repo_id}")
def _resolve_push_target(args: argparse.Namespace) -> Tuple[str, str]:
    """Return ``(token, repo_id)`` for the upload, or raise if unavailable."""
    token, username = resolve_auth(args.credentials_path)
    if token is None:
        raise ValueError("Missing HF token. Set HF_TOKEN or provide credentials JSON.")

    repo_id = as_text(args.repo_id)
    if repo_id:
        return token, repo_id

    if not username:
        raise ValueError("repo_id missing and username unavailable.")

    # Default the repo name to the output directory name. Paths such as "."
    # or ".." have no usable final component, so resolve them first.
    repo_name = args.output_dir.name
    if repo_name in ("", ".."):
        repo_name = args.output_dir.resolve().name
    if not repo_name:
        raise ValueError(
            f"Cannot derive a repo name from output dir: {args.output_dir}"
        )
    return token, f"{username}/{repo_name}"


def main() -> None:
    """Run the CLI: merge the adapter, then optionally push the result.

    When ``--push-to-hub`` is set, the token and repo id are resolved before
    the merge starts, so missing credentials are reported immediately instead
    of after the merge has already been done.

    Raises:
        ValueError: If pushing was requested but no token is available, or no
            repo id was given and none can be derived.
    """
    args = parse_args()

    # Fail fast on upload prerequisites before doing the expensive merge.
    push_target = _resolve_push_target(args) if args.push_to_hub else None

    merge_adapter(args)

    if push_target is None:
        return
    token, repo_id = push_target
    push_merged(args, token=token, repo_id=repo_id)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()