# weights_utils.py
from __future__ import annotations
from pathlib import Path
from typing import Optional
import shutil
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository that hosts the public VICCA weights.
DEFAULT_WEIGHTS_REPO_ID = "sayehghp/vicca-weights"


def get_weight(
    rel_path: str,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Path] = None,
) -> str:
    """Ensure ``rel_path`` exists under ``base_dir`` and return its local path.

    ``rel_path`` is a "/"-separated path relative to the app root, e.g.
    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth" or
    "VG/weights/checkpoint0399.pth". The same relative path is used to
    locate the file inside the Hub repository.

    Args:
        rel_path: Location of the weight file relative to ``base_dir``.
        repo_id: Hub repository to download from when the file is missing.
        base_dir: Root directory the file is mirrored into. Defaults to the
            directory containing this module. A plain string path is also
            tolerated and converted to ``Path``.

    Returns:
        The path of the local copy, as a string. It is absolute when
        ``base_dir`` is omitted or is itself absolute.

    Raises:
        Whatever ``hf_hub_download`` raises when the file cannot be fetched,
        and ``OSError`` if the local copy cannot be written.
    """
    if base_dir is None:
        # /home/user/app in Spaces. resolve() keeps the "absolute path"
        # promise even when __file__ happens to be relative.
        base_dir = Path(__file__).resolve().parent
    local_path = Path(base_dir) / rel_path

    # 1) If we've already mirrored it into the app tree, reuse it.
    if local_path.is_file():
        return str(local_path)

    # 2) Otherwise, fetch it from HF Hub (this uses HF's own cache).
    #    rpartition yields an empty subfolder when rel_path has no "/".
    subfolder, _, filename = rel_path.rpartition("/")
    cached_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder or None,
        # IMPORTANT: don't set force_download=True
    )

    # 3) Mirror into our app tree so legacy code keeps working.
    #    Copy to a sibling temp file and rename it into place: an interrupted
    #    copy must never leave a truncated file at local_path, because step 1
    #    would then treat it as a valid weight on the next call.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    partial_path = local_path.with_name(local_path.name + ".part")
    try:
        shutil.copy2(cached_path, partial_path)
        partial_path.replace(local_path)
    except BaseException:
        # Clean up the partial copy, then let the original error propagate.
        if partial_path.exists():
            partial_path.unlink()
        raise
    return str(local_path)
# Paths, relative to the app root, of the large VICCA weight files that
# ensure_all_vicca_weights() makes available locally. Order is preserved
# when downloading.
WEIGHT_FILES = [
    # --- CheXbert ---
    "CheXbert/checkpoint/chexbert.pth",

    # --- CXRGen: Uniformer annotator ---
    "CXRGen/annotator/ckpts/upernet_global_small.pth",

    # --- CXRGen: diffusion model ---
    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",

    # --- BiomedVLP encoders (one copy under CXRGen, one under VG) ---
    "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
    "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",

    # --- CXRGen: lung U-Nets ---
    "CXRGen/LungDetection/models/unet-2v.pt",
    "CXRGen/LungDetection/models/unet-6v.pt",

    # --- DETR checkpoint ---
    "DETR/output/checkpoint.pth",

    # --- VG checkpoints ---
    "VG/weights/checkpoint0399.pth",
    "VG/weights/checkpoint0399_log4.pth",
    "VG/weights/checkpoint_best_regular.pth",
]
def ensure_all_vicca_weights():
    """
    Make every file listed in WEIGHT_FILES available locally.

    Each entry is passed to get_weight(), which:
    * reuses a copy that is already present, or downloads it through
      the HF cache otherwise
    * mirrors it into the expected relative path

    One progress line is printed per file.
    """
    for weight_rel_path in WEIGHT_FILES:
        resolved = get_weight(weight_rel_path)
        # flush=True so each line is emitted immediately rather than buffered.
        print(
            f"[VICCA] Ensured weight: {weight_rel_path} -> {resolved}",
            flush=True,
        )
|