File size: 698 Bytes
6feb3b2
 
 
 
 
e992b8d
 
 
 
 
 
 
 
6feb3b2
 
e992b8d
6feb3b2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import os
from pathlib import Path
from urllib.parse import urlsplit

from huggingface_hub import snapshot_download

def get_weights_dir(repo_id: str) -> Path:
    # repo_id must be like "BiasLab2025/taskclip-weights" (NOT a URL)
    repo_id = repo_id.strip()
    if repo_id.startswith("http"):
        # allow passing a full URL by accident
        repo_id = repo_id.rstrip("/").split("huggingface.co/")[-1]

    token = os.getenv("HF_TOKEN")  # only needed if the repo is private

    p = snapshot_download(
        repo_id=repo_id,
        repo_type="model",          # IMPORTANT for your weights repo
        local_dir="weights_cache",
        local_dir_use_symlinks=False,
        token=token,
    )
    return Path(p).resolve()