kaveh commited on
Commit
57a673f
·
1 Parent(s): b0a0f7c

download ckp updated

Browse files
Files changed (1) hide show
  1. S2FApp/download_ckp.py +11 -1
S2FApp/download_ckp.py CHANGED
@@ -3,13 +3,23 @@ import os
3
  from pathlib import Path
4
 
5
  ckp = Path("/app/ckp")
6
- if not list(ckp.glob("*.pth")):
 
 
 
 
7
  try:
8
  from huggingface_hub import hf_hub_download, list_repo_files
9
 
10
  repo = os.environ.get("HF_MODEL_REPO", "kaveh/Shape2Force")
11
  files = list_repo_files(repo)
12
  pth_files = [f for f in files if f.startswith("ckp/") and f.endswith(".pth")]
 
 
 
 
 
 
13
  for f in pth_files:
14
  hf_hub_download(repo_id=repo, filename=f, local_dir="/app")
15
  print("Downloaded checkpoints from", repo)
 
3
  from pathlib import Path
4
 
5
  ckp = Path("/app/ckp")
6
+ ckp_single_cell = ckp / "single_cell"
7
+ ckp_spheroid = ckp / "spheroid"
8
+ has_any = list(ckp.glob("*.pth")) or list(ckp_single_cell.glob("*.pth")) or list(ckp_spheroid.glob("*.pth"))
9
+
10
+ if not has_any:
11
  try:
12
  from huggingface_hub import hf_hub_download, list_repo_files
13
 
14
  repo = os.environ.get("HF_MODEL_REPO", "kaveh/Shape2Force")
15
  files = list_repo_files(repo)
16
  pth_files = [f for f in files if f.startswith("ckp/") and f.endswith(".pth")]
17
+ # For spheroid: only download ckp_spheroid_FN.pth (not ckp_spheroid_GN.pth or others)
18
+ def should_download(f):
19
+ if "spheroid" in f and "ckp_spheroid_FN.pth" not in f:
20
+ return False
21
+ return True
22
+ pth_files = [f for f in pth_files if should_download(f)]
23
  for f in pth_files:
24
  hf_hub_download(repo_id=repo, filename=f, local_dir="/app")
25
  print("Downloaded checkpoints from", repo)