# ObjectverseDiary/scripts/publish_hf_adapter.py
# Last commit: 9e874de (verified) by qqyule — "Update Objectverse Diary submission package"
"""Upload a trained Objectverse Diary LoRA adapter folder to Hugging Face Hub."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any
# Files that must all be present in an adapter directory.
REQUIRED_ADAPTER_FILES = ("adapter_config.json",)
# Accepted weight files; at least one of these must be present.
ADAPTER_WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def validate_adapter_dir(adapter_dir: Path) -> dict[str, object]:
    """Check that *adapter_dir* looks like a saved LoRA adapter and summarize it.

    The directory must contain every file named in ``REQUIRED_ADAPTER_FILES``
    and at least one file named in ``ADAPTER_WEIGHT_FILES``.

    Args:
        adapter_dir: Directory to inspect.

    Returns:
        A dict with ``adapter_dir`` (the path as a string), ``files`` (sorted
        names of the regular files directly inside the directory; contents of
        sub-directories are not listed) and ``file_count``.

    Raises:
        FileNotFoundError: If *adapter_dir* does not exist or is not a directory.
        ValueError: If a required config file or every weight file is missing.
    """
    if not adapter_dir.is_dir():
        # is_dir() is False both for missing paths and for non-directories;
        # report which case applies instead of always claiming "does not exist".
        if adapter_dir.exists():
            raise FileNotFoundError(f"Adapter path is not a directory: {adapter_dir}")
        raise FileNotFoundError(f"Adapter directory does not exist: {adapter_dir}")

    # is_file() (not exists()) so a sub-directory that happens to carry a
    # required file name is not mistaken for the file itself.
    missing = [
        name for name in REQUIRED_ADAPTER_FILES if not (adapter_dir / name).is_file()
    ]
    if not any((adapter_dir / name).is_file() for name in ADAPTER_WEIGHT_FILES):
        # Derived from the constant so the hint never drifts from what is accepted.
        missing.append(" or ".join(ADAPTER_WEIGHT_FILES))
    if missing:
        raise ValueError(f"Adapter directory is missing required files: {', '.join(missing)}")

    files = sorted(path.name for path in adapter_dir.iterdir() if path.is_file())
    return {
        "adapter_dir": str(adapter_dir),
        "files": files,
        "file_count": len(files),
    }
def upload_adapter(
    *,
    adapter_dir: Path,
    repo_id: str,
    private: bool,
    commit_message: str,
    dry_run: bool,
) -> dict[str, object]:
    """Validate *adapter_dir* and, unless *dry_run*, push it to a Hub model repo.

    Args:
        adapter_dir: Local adapter folder; checked with ``validate_adapter_dir``.
        repo_id: Target model repository id passed to the Hub API.
        private: Requested visibility, passed to ``create_repo``.
        commit_message: Commit message used for the folder upload.
        dry_run: When True, only validate and report; nothing is uploaded and
            ``huggingface_hub`` is never imported.

    Returns:
        The validation summary extended with ``repo_id``, ``private``,
        ``commit_message``, ``dry_run`` and ``uploaded``; after a real upload
        it also contains ``url``, the repo URL returned by ``create_repo``.

    Raises:
        FileNotFoundError, ValueError: From ``validate_adapter_dir``.
        Any error raised by ``huggingface_hub`` propagates unchanged.
    """
    summary = validate_adapter_dir(adapter_dir)
    summary.update(
        {
            "repo_id": repo_id,
            "private": private,
            "commit_message": commit_message,
            "dry_run": dry_run,
        }
    )
    if dry_run:
        summary["uploaded"] = False
        return summary

    # Imported lazily so dry runs work without huggingface_hub installed.
    from huggingface_hub import HfApi

    api = HfApi()
    # NOTE(review): with exist_ok=True an existing repo is reused, so `private`
    # presumably only takes effect when the repo is created here; the reported
    # "private" value may not match an existing repo's visibility — confirm.
    repo_url = api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
    # NOTE(review): upload_folder sends the folder contents (likely including
    # sub-directories), which can be more than the top-level "files" summary.
    api.upload_folder(
        folder_path=str(adapter_dir),
        repo_id=repo_id,
        repo_type="model",
        commit_message=commit_message,
    )
    summary["uploaded"] = True
    # Use the URL the Hub returned instead of a hard-coded huggingface.co
    # prefix, so custom endpoints and resolved namespaces are reported correctly.
    summary["url"] = str(repo_url)
    return summary
def _print_json(payload: dict[str, Any]) -> None:
print(json.dumps(payload, indent=2, sort_keys=True), flush=True)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--adapter-dir", type=Path, required=True)
parser.add_argument("--repo-id", required=True)
parser.add_argument("--private", action="store_true")
parser.add_argument(
"--commit-message",
default="Upload Objectverse Diary LoRA adapter",
)
parser.add_argument("--dry-run", action="store_true")
return parser.parse_args()
def main() -> None:
    """Command-line entry point: parse options, upload (or dry-run), print the JSON summary."""
    options = _parse_args()
    result = upload_adapter(
        dry_run=options.dry_run,
        commit_message=options.commit_message,
        private=options.private,
        repo_id=options.repo_id,
        adapter_dir=options.adapter_dir,
    )
    _print_json(result)
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        # Report failures as a one-line message on stderr with exit status 1
        # rather than a traceback. Some exceptions have an empty str(), which
        # would print only a blank line, so fall back to the class name.
        raise SystemExit(str(exc) or type(exc).__name__) from exc