# weights_utils.py
"""Helpers for fetching VICCA model weights from the Hugging Face Hub.

Files are downloaded with ``hf_hub_download`` (which keeps its own cache) and
then mirrored into the app tree at the relative paths the legacy code expects.
"""
from __future__ import annotations

import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import hf_hub_download

# Your public weights repo
DEFAULT_WEIGHTS_REPO_ID = "sayehghp/vicca-weights"


def _atomic_copy(src: str, dst: Path) -> None:
    """Copy ``src`` to ``dst`` without ever leaving ``dst`` partially written.

    The data is copied to a uniquely named temporary file in ``dst``'s own
    directory (so it is on the same filesystem) and then renamed over ``dst``
    with ``os.replace``. If the copy or rename fails or is interrupted, the
    temporary file is removed and ``dst`` is left untouched. This matters
    because callers treat "``dst`` is a file" as "``dst`` is complete".

    ``dst.parent`` must already exist.
    """
    fd, tmp_name = tempfile.mkstemp(
        dir=dst.parent, prefix=f".{dst.name}.", suffix=".partial"
    )
    # Only the unique name is needed; copy2 reopens the path itself.
    os.close(fd)
    tmp_path = Path(tmp_name)
    try:
        shutil.copy2(src, tmp_path)
        os.replace(tmp_path, dst)
    except BaseException:
        # Also covers KeyboardInterrupt/SystemExit so no stray temp file remains.
        tmp_path.unlink(missing_ok=True)
        raise


def get_weight(
    rel_path: str,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Union[str, Path]] = None,
) -> str:
    """Ensure ``rel_path`` exists under ``base_dir`` and return that path.

    If ``base_dir / rel_path`` is already a file, it is reused as-is.
    Otherwise the file is fetched from the Hub repo ``repo_id`` at the same
    relative path and copied into place atomically.

    Args:
        rel_path: "/"-separated path relative to the app root. It is also the
            file's path inside the Hub repo, e.g.
            "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth" or
            "VG/weights/checkpoint0399.pth".
        repo_id: Hub repo to download from when the file is missing locally.
        base_dir: Root directory to mirror into. Defaults to the directory
            containing this module. A ``str`` is also accepted.

    Returns:
        ``str(base_dir / rel_path)``. It is absolute only if ``base_dir`` is.
    """
    if base_dir is None:
        # Directory of this module (e.g. /home/user/app on HF Spaces).
        base_dir = Path(__file__).parent
    else:
        base_dir = Path(base_dir)
    local_path = base_dir / rel_path

    # 1) If we've already mirrored it into the repo tree, reuse it. The mirror
    #    is written atomically, so an existing file is a complete one.
    if local_path.is_file():
        return str(local_path)

    # 2) Otherwise, fetch it from HF Hub (this uses HF's own cache).
    #    "a/b/c.pth" -> subfolder "a/b", filename "c.pth"; no "/" -> no subfolder.
    subfolder, _, filename = rel_path.rpartition("/")
    cached_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder or None,
        # IMPORTANT: don't set force_download=True
    )

    # 3) Mirror into our app tree so legacy code keeps working.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    _atomic_copy(cached_path, local_path)
    return str(local_path)


# All the heavy VICCA weights, as paths relative to the app root (and to the
# Hub repo root).
WEIGHT_FILES = [
    # CheXbert
    "CheXbert/checkpoint/chexbert.pth",
    # Uniformer annotator
    "CXRGen/annotator/ckpts/upernet_global_small.pth",
    # CXRGen diffusion model
    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
    # BiomedVLP encoders
    "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
    "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",
    # Lung U-Nets
    "CXRGen/LungDetection/models/unet-2v.pt",
    "CXRGen/LungDetection/models/unet-6v.pt",
    # DETR checkpoint
    "DETR/output/checkpoint.pth",
    # VG checkpoints
    "VG/weights/checkpoint0399.pth",
    "VG/weights/checkpoint0399_log4.pth",
    "VG/weights/checkpoint_best_regular.pth",
]


def ensure_all_vicca_weights(
    *,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Union[str, Path]] = None,
) -> None:
    """Make sure every file in ``WEIGHT_FILES`` is present under ``base_dir``.

    Each file goes through ``get_weight``:

    * Files already mirrored are reused.
    * Missing files are downloaded through the HF cache and copied into
      place.

    Because existing files are skipped, repeated calls are cheap once
    everything is mirrored. Logs one line per file to stdout.

    Args:
        repo_id: Hub repo to download missing files from.
        base_dir: Root directory to mirror into. ``None`` means
            ``get_weight``'s default, the directory of this module.
    """
    for rel_path in WEIGHT_FILES:
        path = get_weight(rel_path, repo_id=repo_id, base_dir=base_dir)
        print(f"[VICCA] Ensured weight: {rel_path} -> {path}", flush=True)