| #!/usr/bin/env python3 | |
| """ | |
| PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation | |
| Official implementation of the paper: | |
| "PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation" | |
| by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis | |
| Licensed under a modified MIT license | |
| """ | |
| # Download and arrange PRIMA demo assets into the expected data/ layout. | |
| # Usage: | |
| # python scripts/setup_demo_data.py | |
| # python scripts/setup_demo_data.py --force | |
| from __future__ import annotations | |
| import argparse | |
| import shutil | |
| import sys | |
| from pathlib import Path | |
| import torch | |
# Hugging Face repo that hosts all PRIMA demo assets.
DEFAULT_HF_REPO_ID = "MLAdaptiveIntelligence/PRIMA"
# SMAL animal model and pose-prior pickle files (downloaded into data/smal/).
SMAL_ASSET_PATHS = [
    "my_smpl_00781_4_all.pkl",
    "my_smpl_data_00781_4_all.pkl",
    "walking_toy_symmetric_pose_prior_with_cov_35parts.pkl",
]
# Pretrained ViT backbone weights (downloaded directly into data/).
BACKBONE_ASSET_PATH = "amr_vitbb.pth"
# Stage-1 and stage-3 Hydra config files and checkpoint files on the Hub.
STAGE1_CONFIG_ASSET_PATH = "config_s1_HYDRA.yaml"
STAGE1_CHECKPOINT_ASSET_PATH = "s1ckpt.ckpt"
STAGE3_CONFIG_ASSET_PATH = "config_s3_HYDRA.yaml"
STAGE3_CHECKPOINT_ASSET_PATH = "s3ckpt.ckpt"
def download_from_hub(hf_repo_id: str, remote_filename: str, dest: Path) -> None:
    """Fetch ``remote_filename`` from the Hub repo and place it exactly at ``dest``.

    The transfer is resumable and goes through the Hugging Face cache. When
    the hub library materializes the file under a nested path inside
    ``dest.parent`` (it mirrors the remote directory structure), the file is
    relocated onto the exact ``dest`` path afterwards.
    """
    # Imported lazily so the script can print a clear error before any
    # heavyweight work if huggingface_hub is missing.
    from huggingface_hub import hf_hub_download

    dest.parent.mkdir(parents=True, exist_ok=True)
    downloaded = hf_hub_download(
        repo_id=hf_repo_id,
        filename=remote_filename,
        local_dir=str(dest.parent),
        local_dir_use_symlinks=False,
    )
    produced = Path(downloaded).resolve()
    wanted = dest.resolve()
    if produced == wanted:
        return
    # The hub wrote to a nested location; move the file onto the target path.
    if wanted.exists():
        wanted.unlink()
    shutil.move(str(produced), str(wanted))
def validate_torch_checkpoint(path: Path) -> None:
    """Ensure the file at ``path`` is a deserializable torch checkpoint.

    Raises:
        RuntimeError: if ``torch.load`` cannot read the file (truncated
            download, corrupt upload, or wrong format); the original
            exception is chained as the cause.
    """
    try:
        torch.load(path, map_location="cpu")
    except Exception as exc:
        message = (
            f"Checkpoint file is invalid or incomplete: {path}\n"
            "Downloaded checkpoint is not loadable. "
            "Please verify the uploaded Hugging Face file and try again."
        )
        raise RuntimeError(message) from exc
def maybe_download_backbone(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Download the pretrained backbone weights into ``data_dir``.

    Skips the download when the target file already exists, unless ``force``
    is set.
    """
    # Derive the local filename from the shared asset constant instead of
    # duplicating the "amr_vitbb.pth" literal, so a rename on the Hub only
    # needs to be updated in one place.
    target = data_dir / Path(BACKBONE_ASSET_PATH).name
    if target.exists() and not force:
        print(f"[skip] {target} already exists")
        return
    print("[download] pretrained backbone")
    download_from_hub(hf_repo_id, BACKBONE_ASSET_PATH, target)
    print(f"[ok] {target}")
def maybe_download_smal(data_dir: Path, force: bool, hf_repo_id: str) -> None:
    """Fetch the SMAL model and prior pickle files into ``data_dir / "smal"``.

    Skips the download when every expected file is already present, unless
    ``force`` is set.
    """
    smal_dir = data_dir / "smal"
    expected_names = [Path(asset).name for asset in SMAL_ASSET_PATHS]
    have_all = smal_dir.exists() and all(
        (smal_dir / name).exists() for name in expected_names
    )
    if have_all and not force:
        print("[skip] SMAL files already exist")
        return
    print("[download] SMAL assets")
    for remote in SMAL_ASSET_PATHS:
        download_from_hub(hf_repo_id, remote, smal_dir / Path(remote).name)
    print(f"[ok] {smal_dir}")
def maybe_download_stage(
    stage_name: str,
    config_asset_path: str,
    checkpoint_asset_path: str,
    ckpt_name: str,
    data_dir: Path,
    force: bool,
    hf_repo_id: str,
) -> None:
    """Download one training stage's Hydra config and checkpoint.

    Layout produced under ``data_dir / stage_name``::

        .hydra/config.yaml        <- config_asset_path
        checkpoints/<ckpt_name>   <- checkpoint_asset_path

    An existing checkpoint is validated with ``torch.load``; only an
    invalid/incomplete one is re-downloaded (unless ``force`` is set).
    """
    stage_dir = data_dir / stage_name
    cfg_target = stage_dir / ".hydra" / "config.yaml"
    ckpt_target = stage_dir / "checkpoints" / ckpt_name
    existing_ckpt_valid = False
    # Validate any existing checkpoint up front, independent of whether the
    # config is present: previously a missing config alone caused a needless
    # re-download of a perfectly valid checkpoint.
    if ckpt_target.exists() and not force:
        try:
            validate_torch_checkpoint(ckpt_target)
            existing_ckpt_valid = True
        except RuntimeError:
            print(f"[warn] {stage_name} checkpoint is incomplete, redownloading checkpoint only.")
    if cfg_target.exists() and existing_ckpt_valid and not force:
        print(f"[skip] {stage_name} assets already exist")
        return
    print(f"[download] {stage_name} assets")
    cfg_target.parent.mkdir(parents=True, exist_ok=True)
    ckpt_target.parent.mkdir(parents=True, exist_ok=True)
    if force or not cfg_target.exists():
        download_from_hub(hf_repo_id, config_asset_path, cfg_target)
    if force or not existing_ckpt_valid:
        download_from_hub(hf_repo_id, checkpoint_asset_path, ckpt_target)
        # Fail fast on a bad transfer instead of surfacing it later in demo.py.
        validate_torch_checkpoint(ckpt_target)
    print(f"[ok] {stage_dir}")
def verify_layout(data_dir: Path) -> None:
    """Check that every demo asset landed where the PRIMA code expects it.

    Raises:
        FileNotFoundError: listing every missing path, when any asset is absent.
        RuntimeError: when a stage checkpoint exists but is not loadable.
    """
    # Build the expected layout from the shared asset constants so this list
    # cannot drift from what the download helpers actually produce.
    stage_ckpts = [
        data_dir / "PRIMAS1" / "checkpoints" / "s1ckpt.ckpt",
        data_dir / "PRIMAS3" / "checkpoints" / "s3ckpt.ckpt",
    ]
    required_paths = (
        [data_dir / "smal" / Path(p).name for p in SMAL_ASSET_PATHS]
        + [data_dir / Path(BACKBONE_ASSET_PATH).name]
        + [
            data_dir / "PRIMAS1" / ".hydra" / "config.yaml",
            data_dir / "PRIMAS3" / ".hydra" / "config.yaml",
        ]
        + stage_ckpts
    )
    missing = [p for p in required_paths if not p.exists()]
    if missing:
        raise FileNotFoundError("Missing required files:\n" + "\n".join(str(p) for p in missing))
    for ckpt in stage_ckpts:
        validate_torch_checkpoint(ckpt)
def main() -> int:
    """CLI entry point: download all demo assets and verify the final layout.

    Returns:
        int: process exit status (0 on success).
    """
    parser = argparse.ArgumentParser(description="Download PRIMA demo checkpoints and data")
    parser.add_argument("--data-dir", type=Path, default=Path("data"), help="Target data directory")
    parser.add_argument("--force", action="store_true", help="Redownload and overwrite existing files")
    parser.add_argument(
        "--hf-repo-id",
        type=str,
        default=DEFAULT_HF_REPO_ID,
        help="Hugging Face repo ID containing demo assets (e.g., org/repo)",
    )
    args = parser.parse_args()

    root = args.data_dir.resolve()
    root.mkdir(parents=True, exist_ok=True)

    maybe_download_smal(root, force=args.force, hf_repo_id=args.hf_repo_id)
    maybe_download_backbone(root, force=args.force, hf_repo_id=args.hf_repo_id)

    # Both stages share the same download logic; drive it from a table.
    stages = [
        ("PRIMAS1", STAGE1_CONFIG_ASSET_PATH, STAGE1_CHECKPOINT_ASSET_PATH, "s1ckpt.ckpt"),
        ("PRIMAS3", STAGE3_CONFIG_ASSET_PATH, STAGE3_CHECKPOINT_ASSET_PATH, "s3ckpt.ckpt"),
    ]
    for stage_name, cfg_asset, ckpt_asset, ckpt_name in stages:
        maybe_download_stage(
            stage_name,
            cfg_asset,
            ckpt_asset,
            ckpt_name,
            root,
            force=args.force,
            hf_repo_id=args.hf_repo_id,
        )

    verify_layout(root)

    print("\n[done] Demo assets ready.")
    print("Run demo:")
    print(" python demo.py --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt --img_folder demo_data/ --out_folder demo_out/")
    print("Run demo with TTA:")
    print(" python demo_tta.py --checkpoint data/PRIMAS1/checkpoints/s1ckpt.ckpt --img_folder demo_data/ --out_folder demo_out_tta/ --tta_lr 1e-6 --tta_num_iters 30")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    sys.exit(main())