File size: 1,205 Bytes
2b4a83e
 
 
 
 
e8917d9
 
 
 
 
2b4a83e
 
 
cd71264
2b4a83e
 
e8917d9
 
 
 
 
 
2b4a83e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Download checkpoints from Hugging Face if ckp is empty (for Space deployment)."""
import os
from pathlib import Path

ckp = Path("/app/ckp")
ckp_single_cell = ckp / "single_cell"
ckp_spheroid = ckp / "spheroid"
# Any *.pth directly inside one of these directories counts as "checkpoints
# present"; any() stops globbing at the first hit.
has_any = any(any(d.glob("*.pth")) for d in (ckp, ckp_single_cell, ckp_spheroid))

# The only spheroid checkpoint to fetch (e.g. ckp_spheroid_GN.pth is skipped).
SPHEROID_CHECKPOINT = "ckp_spheroid_FN.pth"


def should_download(f):
    """Return whether the repository file *f* should be downloaded.

    Paths that mention "spheroid" are kept only when their file name is
    exactly ``ckp_spheroid_FN.pth``; every other path is kept.

    Args:
        f: Repository-relative path with "/" separators, as returned by
            ``huggingface_hub.list_repo_files``.
    """
    if "spheroid" not in f:
        return True
    # Exact basename match: a substring test would also accept names such as
    # "old_ckp_spheroid_FN.pth".
    return f.rpartition("/")[2] == SPHEROID_CHECKPOINT


if not has_any:
    try:
        # Imported lazily so a missing package is reported, not fatal.
        from huggingface_hub import hf_hub_download, list_repo_files

        # Use the default repo when the variable is unset *or* empty.
        repo = os.environ.get("HF_MODEL_REPO") or "Angione-Lab/Shape2Force"
        files = list_repo_files(repo)
        pth_files = [
            f
            for f in files
            if f.startswith("ckp/") and f.endswith(".pth") and should_download(f)
        ]
        failed = []
        for f in pth_files:
            # local_dir="/app" keeps the repo-relative "ckp/..." path, so files
            # land under /app/ckp. Each file is best-effort so one failure
            # does not prevent the rest from being fetched.
            try:
                hf_hub_download(repo_id=repo, filename=f, local_dir="/app")
            except Exception as e:
                failed.append(f)
                print("Could not download", f + ":", e)
        if not pth_files:
            print("No checkpoint files found in", repo)
        elif failed:
            print(
                f"Downloaded {len(pth_files) - len(failed)} of {len(pth_files)} "
                "checkpoints from",
                repo,
            )
        else:
            print("Downloaded checkpoints from", repo)
    except Exception as e:
        # Deliberately best-effort: the app may still start without weights.
        print("Could not download checkpoints:", e)
else:
    print("Checkpoints already present")