"""Merge an Objectverse Diary LoRA adapter into its base Hugging Face model.""" from __future__ import annotations import argparse import json from pathlib import Path from typing import Any ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin") def validate_adapter_source(adapter: str | Path, *, base_model: str) -> dict[str, object]: adapter_text = str(adapter) adapter_path = Path(adapter_text) if adapter_path.exists(): if not adapter_path.is_dir(): raise ValueError(f"Adapter path is not a directory: {adapter_path}") config_path = adapter_path / "adapter_config.json" if not config_path.exists(): raise ValueError(f"Adapter directory is missing adapter_config.json: {adapter_path}") if not any((adapter_path / name).exists() for name in ADAPTER_WEIGHT_FILES): raise ValueError( "Adapter directory is missing adapter_model.safetensors or adapter_model.bin." ) config = _read_adapter_config(config_path) configured_base = config.get("base_model_name_or_path") if configured_base and str(configured_base) != base_model: raise ValueError( f"Adapter base model is {configured_base!r}, expected {base_model!r}." ) return { "adapter": str(adapter_path), "adapter_type": "local", "adapter_base_model": configured_base or "", } if "/" not in adapter_text: raise FileNotFoundError(f"Adapter source does not exist: {adapter_text}") return { "adapter": adapter_text, "adapter_type": "hub", "adapter_base_model": "", } def plan_merge( *, base_model: str, adapter: str | Path, output: Path, dry_run: bool, ) -> dict[str, object]: summary = validate_adapter_source(adapter, base_model=base_model) summary.update( { "base_model": base_model, "output": str(output), "dry_run": dry_run, } ) if dry_run: summary["merged"] = False return summary merge_lora_adapter( base_model=base_model, adapter=str(adapter), output=output, ) summary["merged"] = True summary["files"] = sorted(path.name for path in output.iterdir() if path.is_file()) return summary def merge_lora_adapter( *, base_model: str, adapter: str, output: Path, ) -> None: from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer output.mkdir(parents=True, exist_ok=True) model = AutoModelForCausalLM.from_pretrained( base_model, torch_dtype="auto", device_map={"": "cpu"}, low_cpu_mem_usage=True, ) peft_model = PeftModel.from_pretrained(model, adapter) merged = peft_model.merge_and_unload(safe_merge=True) merged.save_pretrained( output, safe_serialization=True, max_shard_size="2GB", ) tokenizer = AutoTokenizer.from_pretrained(adapter if Path(adapter).exists() else base_model) tokenizer.save_pretrained(output) metadata = { "base_model": base_model, "adapter": adapter, "output": str(output), "format": "merged-hf", } (output / "objectverse_merge_metadata.json").write_text( json.dumps(metadata, indent=2, sort_keys=True), encoding="utf-8", ) def _read_adapter_config(config_path: Path) -> dict[str, object]: try: payload = json.loads(config_path.read_text(encoding="utf-8")) except json.JSONDecodeError as exc: raise ValueError(f"Invalid adapter_config.json: {exc.msg}") from exc if not isinstance(payload, dict): raise ValueError("adapter_config.json must contain a JSON object.") return payload def _print_json(payload: dict[str, Any]) -> None: print(json.dumps(payload, indent=2, sort_keys=True), flush=True) def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--base-model", required=True) parser.add_argument("--adapter", required=True) parser.add_argument("--output", type=Path, required=True) parser.add_argument("--dry-run", action="store_true") return parser.parse_args() def main() -> None: args = _parse_args() _print_json( plan_merge( base_model=args.base_model, adapter=args.adapter, output=args.output, dry_run=args.dry_run, ) ) if __name__ == "__main__": try: main() except Exception as exc: raise SystemExit(str(exc)) from exc