# File size: 778 Bytes
# 6b92ff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os


# Hugging Face repository id that checkpoints are downloaded from.
HF_REPO_ID = "VAST-AI/AniGen"
# Directory (relative to the download root) that holds all checkpoints.
CKPTS_DIR = "ckpts"


# Files that must exist (relative to the download root) for the checkpoints
# to be considered present; every entry lives under CKPTS_DIR.
REQUIRED_FILES = [
    f"{CKPTS_DIR}/{relative_path}"
    for relative_path in (
        "dinov2/hubconf.py",
        "dinov2/dinov2/__init__.py",
        "dsine/dsine.pt",
        "vgg/vgg16-397923af.pth",
    )
]


def _missing_ckpts(local_dir: str) -> list:
    """Return the entries of REQUIRED_FILES that do not exist under *local_dir*."""
    return [
        path for path in REQUIRED_FILES
        if not os.path.exists(os.path.join(local_dir, path))
    ]


def ensure_ckpts(local_dir: str = ".") -> None:
    """Make sure the required checkpoint files are present under *local_dir*.

    If every path in ``REQUIRED_FILES`` already exists, nothing is done.
    Otherwise the ``CKPTS_DIR`` subtree of the ``HF_REPO_ID`` repository is
    fetched with ``huggingface_hub.snapshot_download`` into *local_dir*, and
    the required files are checked again so that an incomplete download is
    reported instead of being announced as a success.

    Args:
        local_dir: Root directory that the ``REQUIRED_FILES`` paths are
            resolved against and that the snapshot is downloaded into.
    """
    missing = _missing_ckpts(local_dir)
    if not missing:
        return

    print(f"Missing checkpoint files: {missing}")
    print(f"Downloading checkpoints from Hugging Face ({HF_REPO_ID}) ...")

    # Imported lazily so huggingface_hub is only needed when a download is
    # actually required.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=HF_REPO_ID,
        allow_patterns=[f"{CKPTS_DIR}/**"],
        local_dir=local_dir,
    )

    # Re-check: the download returning does not by itself prove that every
    # required file is now in place.
    still_missing = _missing_ckpts(local_dir)
    if still_missing:
        print(
            "Warning: checkpoint files still missing after download: "
            f"{still_missing}"
        )
        return

    print("Checkpoint download complete.")