import os

# Hugging Face repository that hosts the checkpoint files.
HF_REPO_ID = "VAST-AI/AniGen"

# Folder (inside the repository, and mirrored under the local download root)
# that holds all checkpoints; ensure_ckpts downloads everything below it.
CKPTS_DIR = "ckpts"

# Paths, relative to the local download root, that ensure_ckpts checks for.
# If any of them is absent, the CKPTS_DIR folder is downloaded. They are
# built from CKPTS_DIR so the existence check and the download pattern can
# never drift apart.
REQUIRED_FILES = [
    f"{CKPTS_DIR}/dinov2/hubconf.py",
    f"{CKPTS_DIR}/dinov2/dinov2/__init__.py",
    f"{CKPTS_DIR}/dsine/dsine.pt",
    f"{CKPTS_DIR}/vgg/vgg16-397923af.pth",
]
def _missing_ckpts(local_dir):
    """Return the entries of REQUIRED_FILES that are not files under *local_dir*."""
    return [
        path
        for path in REQUIRED_FILES
        # isfile rather than exists: a stray directory sitting at a
        # checkpoint path must not be mistaken for the checkpoint itself.
        if not os.path.isfile(os.path.join(local_dir, path))
    ]


def ensure_ckpts(local_dir: str = ".") -> None:
    """Make sure every file in REQUIRED_FILES is present under *local_dir*.

    If any required file is missing, the whole ``CKPTS_DIR`` folder of the
    Hugging Face repository ``HF_REPO_ID`` is downloaded into *local_dir*
    with ``huggingface_hub.snapshot_download``. The required files are then
    checked again.

    Args:
        local_dir: Root directory that the paths in REQUIRED_FILES are
            relative to; also the download destination. Defaults to the
            current working directory.

    Raises:
        FileNotFoundError: If required files are still missing after the
            download (e.g. the repository layout no longer matches
            REQUIRED_FILES).
    """
    missing = _missing_ckpts(local_dir)
    if not missing:
        return

    print(f"Missing checkpoint files: {missing}")
    print(f"Downloading checkpoints from Hugging Face ({HF_REPO_ID}) ...")
    # Imported lazily so huggingface_hub is only needed when a download
    # actually has to happen.
    from huggingface_hub import snapshot_download

    # The whole folder is requested rather than just `missing`; presumably
    # the listed files act as sentinels for larger directories (e.g. the
    # dinov2 package) -- confirm before narrowing the pattern.
    snapshot_download(
        repo_id=HF_REPO_ID,
        allow_patterns=[f"{CKPTS_DIR}/**"],
        local_dir=local_dir,
    )

    # A download that returns normally does not guarantee the repository
    # contains every expected file; fail loudly here instead of letting a
    # later checkpoint load fail with an unrelated error.
    still_missing = _missing_ckpts(local_dir)
    if still_missing:
        raise FileNotFoundError(
            f"Checkpoint files still missing after download from "
            f"{HF_REPO_ID}: {still_missing}"
        )
    print("Checkpoint download complete.")