Spaces:
Running on Zero
Running on Zero
| """Merge an Objectverse Diary LoRA adapter into its base Hugging Face model.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin") | |
| def validate_adapter_source(adapter: str | Path, *, base_model: str) -> dict[str, object]: | |
| adapter_text = str(adapter) | |
| adapter_path = Path(adapter_text) | |
| if adapter_path.exists(): | |
| if not adapter_path.is_dir(): | |
| raise ValueError(f"Adapter path is not a directory: {adapter_path}") | |
| config_path = adapter_path / "adapter_config.json" | |
| if not config_path.exists(): | |
| raise ValueError(f"Adapter directory is missing adapter_config.json: {adapter_path}") | |
| if not any((adapter_path / name).exists() for name in ADAPTER_WEIGHT_FILES): | |
| raise ValueError( | |
| "Adapter directory is missing adapter_model.safetensors or adapter_model.bin." | |
| ) | |
| config = _read_adapter_config(config_path) | |
| configured_base = config.get("base_model_name_or_path") | |
| if configured_base and str(configured_base) != base_model: | |
| raise ValueError( | |
| f"Adapter base model is {configured_base!r}, expected {base_model!r}." | |
| ) | |
| return { | |
| "adapter": str(adapter_path), | |
| "adapter_type": "local", | |
| "adapter_base_model": configured_base or "", | |
| } | |
| if "/" not in adapter_text: | |
| raise FileNotFoundError(f"Adapter source does not exist: {adapter_text}") | |
| return { | |
| "adapter": adapter_text, | |
| "adapter_type": "hub", | |
| "adapter_base_model": "", | |
| } | |
def plan_merge(
    *,
    base_model: str,
    adapter: str | Path,
    output: Path,
    dry_run: bool,
) -> dict[str, object]:
    """Validate the adapter, optionally run the merge, and report the outcome.

    Args:
        base_model: Base model identifier to merge the adapter into.
        adapter: Adapter reference handed to ``validate_adapter_source``.
        output: Directory that receives the merged model.
        dry_run: When true, stop after validation and merge nothing.

    Returns:
        The validation summary extended with ``base_model``, ``output``,
        ``dry_run`` and ``merged``; after a real merge it also carries
        ``files``, the sorted names of the regular files in *output*.
    """
    report = validate_adapter_source(adapter, base_model=base_model)
    report["base_model"] = base_model
    report["output"] = str(output)
    report["dry_run"] = dry_run

    if not dry_run:
        merge_lora_adapter(base_model=base_model, adapter=str(adapter), output=output)
        # Only top-level regular files are listed; sub-directories are skipped.
        file_names = [entry.name for entry in output.iterdir() if entry.is_file()]
        file_names.sort()
        report["merged"] = True
        report["files"] = file_names
        return report

    report["merged"] = False
    return report
def merge_lora_adapter(
    *,
    base_model: str,
    adapter: str,
    output: Path,
) -> None:
    """Merge a LoRA adapter into its base model and save the result to *output*.

    The base model is loaded on CPU, the adapter is applied and folded into
    the weights, and the merged model is written as safetensors shards of at
    most 2GB. A tokenizer and an ``objectverse_merge_metadata.json`` file are
    saved alongside the weights.

    Args:
        base_model: Identifier or path of the base causal language model.
        adapter: Local directory or Hub id of the LoRA adapter.
        output: Destination directory; created (with parents) if missing.
    """
    # Imported here rather than at module level, so code paths that never
    # merge do not need these packages to be importable.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    output.mkdir(parents=True, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype="auto",
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(model, adapter)
    merged = peft_model.merge_and_unload(safe_merge=True)
    merged.save_pretrained(
        output,
        safe_serialization=True,
        max_shard_size="2GB",
    )

    # Prefer the adapter's tokenizer, but only when the adapter directory
    # actually ships tokenizer files. An adapter saved without them would
    # otherwise make the run fail here, after the merged weights were already
    # written; the base model's tokenizer is the fallback in that case.
    tokenizer_markers = ("tokenizer_config.json", "tokenizer.json", "tokenizer.model")
    adapter_dir = Path(adapter)
    has_adapter_tokenizer = adapter_dir.is_dir() and any(
        (adapter_dir / name).exists() for name in tokenizer_markers
    )
    tokenizer = AutoTokenizer.from_pretrained(adapter if has_adapter_tokenizer else base_model)
    tokenizer.save_pretrained(output)

    metadata = {
        "base_model": base_model,
        "adapter": adapter,
        "output": str(output),
        "format": "merged-hf",
    }
    (output / "objectverse_merge_metadata.json").write_text(
        json.dumps(metadata, indent=2, sort_keys=True),
        encoding="utf-8",
    )
def _read_adapter_config(config_path: Path) -> dict[str, object]:
    """Load ``adapter_config.json`` and return its top-level JSON object.

    Raises:
        ValueError: The file is not valid JSON, or its top-level value is
            not a JSON object.
    """
    raw_text = config_path.read_text(encoding="utf-8")
    try:
        parsed = json.loads(raw_text)
    except json.JSONDecodeError as decode_error:
        raise ValueError(f"Invalid adapter_config.json: {decode_error.msg}") from decode_error

    if isinstance(parsed, dict):
        return parsed
    raise ValueError("adapter_config.json must contain a JSON object.")
def _print_json(payload: dict[str, Any]) -> None:
    """Write *payload* to stdout as indented, key-sorted JSON and flush."""
    rendered = json.dumps(payload, sort_keys=True, indent=2)
    print(rendered, flush=True)
def _parse_args() -> argparse.Namespace:
    """Parse the command line into base model, adapter, output and dry-run."""
    cli = argparse.ArgumentParser(description=__doc__)
    # Both model references are mandatory plain strings.
    for flag in ("--base-model", "--adapter"):
        cli.add_argument(flag, required=True)
    cli.add_argument("--output", required=True, type=Path)
    cli.add_argument("--dry-run", action="store_true")
    namespace = cli.parse_args()
    return namespace
def main() -> None:
    """Run the command-line flow: parse arguments, plan the merge, print JSON."""
    options = _parse_args()
    summary = plan_merge(
        base_model=options.base_model,
        adapter=options.adapter,
        output=options.output,
        dry_run=options.dry_run,
    )
    _print_json(summary)
if __name__ == "__main__":
    # Script entry point: any failure becomes a SystemExit carrying only the
    # error text, with the original exception chained as the cause.
    try:
        main()
    except Exception as error:
        message = str(error)
        raise SystemExit(message) from error