ObjectverseDiary / scripts /merge_lora_adapter.py
qqyule's picture
Deploy latest Objectverse Diary from fa09aac
dd6cefc verified
"""Merge an Objectverse Diary LoRA adapter into its base Hugging Face model."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
# File names accepted as LoRA adapter weights: a local adapter directory must
# contain at least one of them to pass validate_adapter_source().
ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")
def validate_adapter_source(adapter: str | Path, *, base_model: str) -> dict[str, object]:
    """Check that *adapter* names a usable LoRA adapter and describe it.

    A path that exists on disk is treated as a local adapter directory and is
    validated strictly. It must be a directory containing ``adapter_config.json``
    plus one of ``ADAPTER_WEIGHT_FILES``. Any ``base_model_name_or_path`` recorded
    in that config must equal *base_model* exactly. A string that does not exist
    on disk but contains ``/`` is assumed to be a Hugging Face Hub repo id and is
    passed through without further checks.

    Args:
        adapter: Local adapter directory or Hub repo id.
        base_model: Base model id/path the adapter is expected to target.

    Returns:
        A dict with keys ``adapter``, ``adapter_type`` (``"local"`` or ``"hub"``)
        and ``adapter_base_model`` (the recorded base model, or ``""``).

    Raises:
        ValueError: A local adapter is malformed or targets a different base model.
        FileNotFoundError: *adapter* is neither an existing path nor repo-id-like.
    """
    source = str(adapter)
    candidate = Path(source)

    if not candidate.exists():
        # Not on disk: only accept strings shaped like a Hub repo id ("org/name").
        if "/" in source:
            return {
                "adapter": source,
                "adapter_type": "hub",
                "adapter_base_model": "",
            }
        raise FileNotFoundError(f"Adapter source does not exist: {source}")

    if not candidate.is_dir():
        raise ValueError(f"Adapter path is not a directory: {candidate}")

    config_file = candidate / "adapter_config.json"
    if not config_file.exists():
        raise ValueError(f"Adapter directory is missing adapter_config.json: {candidate}")

    has_weights = any(candidate.joinpath(weight).exists() for weight in ADAPTER_WEIGHT_FILES)
    if not has_weights:
        raise ValueError(
            "Adapter directory is missing adapter_model.safetensors or adapter_model.bin."
        )

    recorded_base = _read_adapter_config(config_file).get("base_model_name_or_path")
    # The match is only enforced when the adapter actually records its base model.
    if recorded_base and str(recorded_base) != base_model:
        raise ValueError(f"Adapter base model is {recorded_base!r}, expected {base_model!r}.")

    return {
        "adapter": str(candidate),
        "adapter_type": "local",
        "adapter_base_model": recorded_base or "",
    }
def plan_merge(
    *,
    base_model: str,
    adapter: str | Path,
    output: Path,
    dry_run: bool,
) -> dict[str, object]:
    """Validate the adapter and, unless *dry_run* is set, perform the merge.

    Validation always runs, so a dry run still fails on a malformed local
    adapter. The returned summary contains the fields produced by
    ``validate_adapter_source`` plus ``base_model``, ``output``, ``dry_run`` and
    ``merged``. After a real merge it also contains ``files``: the sorted names
    of the regular files found directly inside *output*.
    """
    plan = validate_adapter_source(adapter, base_model=base_model)
    plan["base_model"] = base_model
    plan["output"] = str(output)
    plan["dry_run"] = dry_run

    if dry_run:
        plan["merged"] = False
        return plan

    merge_lora_adapter(base_model=base_model, adapter=str(adapter), output=output)
    plan["merged"] = True
    # Lists every regular file in output, which may include files that were
    # already there before this merge.
    plan["files"] = sorted(entry.name for entry in output.iterdir() if entry.is_file())
    return plan
def merge_lora_adapter(
    *,
    base_model: str,
    adapter: str,
    output: Path,
) -> None:
    """Merge *adapter* into *base_model* on CPU and save the result to *output*.

    Writes the merged weights as safetensors shards of at most 2GB, plus the
    tokenizer and an ``objectverse_merge_metadata.json`` file recording the
    inputs. *output* is created if needed; files already in it are not removed.

    The tokenizer is taken from a local adapter directory only when that
    directory ships its own tokenizer (``tokenizer_config.json``). Otherwise it
    is loaded from *base_model*. A PEFT adapter directory normally holds only
    ``adapter_config.json`` and the adapter weights. Loading a tokenizer from
    such a directory would fail after the merged weights were already written.

    Args:
        base_model: Hugging Face model id or local path of the base model.
        adapter: Local adapter directory or Hub repo id.
        output: Directory to write the merged model into.
    """
    # Imported here rather than at module level, so dry runs never import peft/transformers.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    output.mkdir(parents=True, exist_ok=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype="auto",
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(model, adapter)
    # safe_merge makes PEFT check the merged weights for NaNs before committing them.
    merged = peft_model.merge_and_unload(safe_merge=True)
    merged.save_pretrained(
        output,
        safe_serialization=True,
        max_shard_size="2GB",
    )

    adapter_dir = Path(adapter)
    has_own_tokenizer = adapter_dir.is_dir() and (adapter_dir / "tokenizer_config.json").is_file()
    tokenizer_source = adapter if has_own_tokenizer else base_model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    tokenizer.save_pretrained(output)

    metadata = {
        "base_model": base_model,
        "adapter": adapter,
        "output": str(output),
        "format": "merged-hf",
    }
    (output / "objectverse_merge_metadata.json").write_text(
        json.dumps(metadata, indent=2, sort_keys=True),
        encoding="utf-8",
    )
def _read_adapter_config(config_path: Path) -> dict[str, object]:
try:
payload = json.loads(config_path.read_text(encoding="utf-8"))
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid adapter_config.json: {exc.msg}") from exc
if not isinstance(payload, dict):
raise ValueError("adapter_config.json must contain a JSON object.")
return payload
def _print_json(payload: dict[str, Any]) -> None:
print(json.dumps(payload, indent=2, sort_keys=True), flush=True)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--base-model", required=True)
parser.add_argument("--adapter", required=True)
parser.add_argument("--output", type=Path, required=True)
parser.add_argument("--dry-run", action="store_true")
return parser.parse_args()
def main() -> None:
    """CLI entry point: parse arguments, run the merge plan, and print its summary as JSON."""
    options = _parse_args()
    result = plan_merge(
        base_model=options.base_model,
        adapter=options.adapter,
        output=options.output,
        dry_run=options.dry_run,
    )
    _print_json(result)
if __name__ == "__main__":
    try:
        main()
    except Exception as failure:
        # Report any failure as a one-line message and a non-zero exit status,
        # keeping the original exception chained as the cause.
        raise SystemExit(str(failure)) from failure