PRIMA-demo / scripts /setup_demo_data.py
HF Space deploy
Deploy snapshot (no PNG/JPG in git per HF policy)
873f551
#!/usr/bin/env python3
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
# Download and arrange PRIMA demo assets into the expected data/ layout.
# Usage:
# python scripts/setup_demo_data.py
# python scripts/setup_demo_data.py --force
from __future__ import annotations
import argparse
import shutil
import sys
from pathlib import Path
import torch
# Hugging Face Hub repo that hosts every demo asset; overridable with --hf-repo-id.
DEFAULT_HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
# SMAL asset files inside the Hub repo; each is saved as <data-dir>/smal/<basename>.
SMAL_ASSET_PATHS = [
"my_smpl_00781_4_all.pkl",
"my_smpl_data_00781_4_all.pkl",
"walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]
# Remote filenames inside the Hub repo. Local placement (see the maybe_download_* helpers):
#   backbone       -> <data-dir>/amr_vitbb.pth
#   stage configs  -> <data-dir>/PRIMAS1|PRIMAS3/.hydra/config.yaml (renamed on download)
#   stage ckpts    -> <data-dir>/PRIMAS1|PRIMAS3/checkpoints/<ckpt name>
BACKBONE_ASSET_PATH = "amr_vitbb.pth"
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt.ckpt"
def download_from_hub(
    hf_repo_id: str, remote_filename: str, dest: Path, *, force: bool = False
) -> None:
    """Download ``remote_filename`` from the Hub repo ``hf_repo_id`` to exactly ``dest``.

    The file is fetched into ``dest.parent`` with ``hf_hub_download(local_dir=...)``
    and, when the remote name differs from ``dest.name``, renamed onto ``dest``.

    Args:
        hf_repo_id: Hub repository ID (``org/repo``).
        remote_filename: Path of the file inside the repo.
        dest: Final local path; missing parent directories are created.
        force: Forwarded as ``force_download`` so the Hub client re-fetches even
            when it considers an existing local copy up to date. Defaults to
            False, which matches the previous behavior.
    """
    from huggingface_hub import hf_hub_download

    dest.parent.mkdir(parents=True, exist_ok=True)
    got = hf_hub_download(
        repo_id=hf_repo_id,
        filename=remote_filename,
        local_dir=str(dest.parent),
        # NOTE(review): deprecated/ignored in recent huggingface_hub releases;
        # kept so older releases still write a real file instead of a symlink.
        local_dir_use_symlinks=False,
        force_download=force,
    )
    got_path = Path(got).resolve()
    target = dest.resolve()
    if got_path != target:
        # Path.replace overwrites an existing target in one rename. got_path lives
        # under dest.parent, so this is a same-filesystem move and a previous copy
        # of ``dest`` is never deleted before the new file is in place.
        got_path.replace(target)
def validate_torch_checkpoint(path: Path) -> None:
    """Raise ``RuntimeError`` unless ``path`` can be fully loaded with ``torch.load``.

    The checkpoint is deserialized onto the CPU and discarded. This catches
    truncated or corrupted downloads.

    Raises:
        RuntimeError: if loading fails for any reason (the original error is chained).
    """
    try:
        try:
            # Training checkpoints may pickle non-tensor objects (e.g. saved
            # hyper-parameters). torch>=2.6 defaults to weights_only=True, which
            # rejects those and would misreport a good file as broken.
            # SECURITY: weights_only=False unpickles arbitrary objects. Only use
            # this on files from a repo you trust (the one given by --hf-repo-id).
            torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch<1.13 has no ``weights_only`` keyword; retry without it.
            torch.load(path, map_location="cpu")
    except Exception as exc:
        raise RuntimeError(
            f"Checkpoint file is invalid or incomplete: {path}\n"
            "Downloaded checkpoint is not loadable. "
            "Please verify the uploaded Hugging Face file and try again."
        ) from exc
def maybe_download_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Place the pretrained backbone weights at ``<data_dir>/amr_vitbb.pth``.

    An already-present file is left untouched unless ``force`` is True.
    """
    backbone_path = data_dir / "amr_vitbb.pth"
    if force or not backbone_path.exists():
        print("[download] pretrained backbone")
        download_from_hub(hf_repo_id, BACKBONE_ASSET_PATH, backbone_path)
        print(f"[ok] {backbone_path}")
    else:
        print(f"[skip] {backbone_path} already exists")
def maybe_download_smal(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Fetch every SMAL asset into ``<data_dir>/smal`` under its basename.

    Nothing is downloaded when all files are already present and ``force`` is
    False. Otherwise each asset is requested through ``download_from_hub``.
    """
    local_names = [Path(asset).name for asset in SMAL_ASSET_PATHS]
    smal_dir = data_dir / "smal"
    have_all = smal_dir.exists() and all((smal_dir / name).exists() for name in local_names)
    if have_all and not force:
        print("[skip] SMAL files already exist")
        return

    print("[download] SMAL assets")
    for asset, name in zip(SMAL_ASSET_PATHS, local_names):
        download_from_hub(hf_repo_id, asset, smal_dir / name)
    print(f"[ok] {smal_dir}")
def maybe_download_stage(
    stage_name: str,
    config_asset_path: str,
    checkpoint_asset_path: str,
    ckpt_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
) -> None:
    """Ensure one stage's Hydra config and checkpoint exist under ``data_dir``.

    Produces::

        <data_dir>/<stage_name>/.hydra/config.yaml
        <data_dir>/<stage_name>/checkpoints/<ckpt_name>

    Unless ``force`` is True, an existing config is kept, and an existing
    checkpoint is kept when ``validate_torch_checkpoint`` accepts it. Only the
    missing or invalid pieces are downloaded.

    Raises:
        RuntimeError: if a freshly downloaded checkpoint still fails validation.
    """
    stage_dir = data_dir / stage_name
    cfg_target = stage_dir / ".hydra" / "config.yaml"
    ckpt_target = stage_dir / "checkpoints" / ckpt_name

    need_cfg = force or not cfg_target.exists()
    need_ckpt = True
    if ckpt_target.exists() and not force:
        # Validate independently of the config, so a good (large) checkpoint is
        # not re-downloaded just because the small config file is missing.
        try:
            validate_torch_checkpoint(ckpt_target)
            need_ckpt = False
        except RuntimeError:
            print(f"[warn] {stage_name} checkpoint is invalid or incomplete, redownloading it.")
            # Remove the broken file so the Hub client cannot treat it as an
            # up-to-date local copy and skip the download.
            ckpt_target.unlink()

    if not (need_cfg or need_ckpt):
        print(f"[skip] {stage_name} assets already exist")
        return

    print(f"[download] {stage_name} assets")
    cfg_target.parent.mkdir(parents=True, exist_ok=True)
    ckpt_target.parent.mkdir(parents=True, exist_ok=True)
    if need_cfg:
        download_from_hub(hf_repo_id, config_asset_path, cfg_target)
    if need_ckpt:
        download_from_hub(hf_repo_id, checkpoint_asset_path, ckpt_target)
        validate_torch_checkpoint(ckpt_target)
    print(f"[ok] {stage_dir}")
def verify_layout(data_dir: Path) -> None:
    """Check that every demo asset is present under ``data_dir`` and both checkpoints load.

    Raises:
        FileNotFoundError: listing every expected file that is absent (one per line).
        RuntimeError: from ``validate_torch_checkpoint`` if a checkpoint cannot be loaded.
    """
    smal_dir = data_dir / "smal"
    s1_ckpt = data_dir / "PRIMAS1" / "checkpoints" / "s1ckpt.ckpt"
    s3_ckpt = data_dir / "PRIMAS3" / "checkpoints" / "s3ckpt.ckpt"
    expected = (
        smal_dir / "my_smpl_00781_4_all.pkl",
        smal_dir / "my_smpl_data_00781_4_all.pkl",
        smal_dir / "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
        data_dir / "amr_vitbb.pth",
        data_dir / "PRIMAS1" / ".hydra" / "config.yaml",
        s1_ckpt,
        data_dir / "PRIMAS3" / ".hydra" / "config.yaml",
        s3_ckpt,
    )
    absent = "\n".join(str(p) for p in expected if not p.exists())
    if absent:
        raise FileNotFoundError(f"Missing required files:\n{absent}")
    for ckpt in (s1_ckpt, s3_ckpt):
        validate_torch_checkpoint(ckpt)
def main() -> int:
    """Command-line entry point: fetch, verify, and report the PRIMA demo assets.

    Returns:
        0 on success. Failures propagate as exceptions, giving a non-zero exit.
    """
    parser = argparse.ArgumentParser(description="Download PRIMA demo checkpoints and data")
    parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Target data directory")
    parser.add_argument("--force", action="store_true", help="Redownload and overwrite existing files")
    parser.add_argument(
        "--hf-repo-id",
        type=str,
        default=DEFAULT_HF_REPO_ID,
        help="Hugging Face repo ID containing demo assets (e.g., org/repo)",
    )
    args = parser.parse_args()
    data_dir = args.data_dir.resolve()
    data_dir.mkdir(parents=True, exist_ok=True)

    maybe_download_smal(data_dir, force=args.force, hf_repo_id=args.hf_repo_id)
    maybe_download_backbone(data_dir, force=args.force, hf_repo_id=args.hf_repo_id)
    maybe_download_stage(
        "PRIMAS1",
        STAGE1_CONFIG_ASSET_PATH,
        STAGE1_CHECKPOINT_ASSET_PATH,
        "s1ckpt.ckpt",
        data_dir,
        force=args.force,
        hf_repo_id=args.hf_repo_id,
    )
    maybe_download_stage(
        "PRIMAS3",
        STAGE3_CONFIG_ASSET_PATH,
        STAGE3_CHECKPOINT_ASSET_PATH,
        "s3ckpt.ckpt",
        data_dir,
        force=args.force,
        hf_repo_id=args.hf_repo_id,
    )
    verify_layout(data_dir)

    # Build the suggested checkpoint path from --data-dir as the user typed it
    # (not resolved), so the commands are correct for a non-default location.
    # With the default this prints data/PRIMAS1/checkpoints/s1ckpt.ckpt, as before.
    s1_ckpt = (args.data_dir / "PRIMAS1" / "checkpoints" / "s1ckpt.ckpt").as_posix()
    print("\n[done] Demo assets ready.")
    print("Run demo:")
    print(f" python demo.py --checkpoint {s1_ckpt} --img_folder demo_data/ --out_folder demo_out/")
    print("Run demo with TTA:")
    print(
        f" python demo_tta.py --checkpoint {s1_ckpt} --img_folder demo_data/"
        " --out_folder demo_out_tta/ --tta_lr 1e-6 --tta_num_iters 30"
    )
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())