vicca / weights_utils.py
sayehghp's picture
Add application file
0f8411f
# weights_utils.py
from __future__ import annotations
from pathlib import Path
from typing import Optional
import shutil
from huggingface_hub import hf_hub_download
# Default Hugging Face Hub repo ID that get_weight() downloads weight files
# from (described by the original author as a public weights repo).
DEFAULT_WEIGHTS_REPO_ID = "sayehghp/vicca-weights"
def get_weight(
    rel_path: str,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Path] = None,
) -> str:
    """Ensure the weight file ``rel_path`` exists locally and return its path.

    If ``base_dir / rel_path`` is already a file it is returned immediately.
    Otherwise the file is fetched with ``hf_hub_download`` and copied to that
    location so code that expects the weights at fixed relative paths keeps
    working.

    Args:
        rel_path: Path relative to the app root, using ``/`` separators, e.g.
            ``"CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth"`` or
            ``"VG/weights/checkpoint0399.pth"``. Everything before the last
            ``/`` is passed to the Hub as ``subfolder``; the rest is the
            ``filename``.
        repo_id: Hub repository to download from when the file is missing.
        base_dir: Directory the file is mirrored into. Defaults to the
            directory containing this module. A ``str`` is accepted as well
            as a ``Path``.

    Returns:
        ``str(base_dir / rel_path)``. The path is not resolved, so it is
        absolute only when ``base_dir`` is absolute.

    Raises:
        Whatever ``hf_hub_download`` raises when the download fails, and
        ``OSError`` if the local directory or file cannot be written.
    """
    # Default root: the directory this module lives in
    # (original author's note: /home/user/app in Spaces).
    root = Path(__file__).parent if base_dir is None else Path(base_dir)
    local_path = root / rel_path

    # 1) If we've already mirrored it into the app tree, reuse it.
    if local_path.is_file():
        return str(local_path)

    # Only needed when we are about to write; when the file exists its parent
    # necessarily exists too.
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # 2) Otherwise, fetch it from HF Hub. rpartition yields an empty subfolder
    #    when rel_path has no "/", which is mapped to None below.
    subfolder, _, filename = rel_path.rpartition("/")
    cached_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder or None,
        # IMPORTANT: don't set force_download=True
    )

    # 3) Mirror into our app tree so legacy code keeps working.
    #    Copy to a sibling temp file and rename it into place: an interrupted
    #    copy must never leave a truncated file at local_path, because step 1
    #    would then hand that partial checkpoint to every later caller.
    tmp_path = local_path.with_name(local_path.name + ".tmp")
    try:
        shutil.copy2(cached_path, tmp_path)
        # Same-directory rename, so the destination appears all at once.
        tmp_path.replace(local_path)
    finally:
        # After a successful replace the temp file is gone; this only fires
        # when the copy or rename failed part-way.
        if tmp_path.exists():
            tmp_path.unlink()
    return str(local_path)
# Weight files mirrored by ensure_all_vicca_weights(), given as paths relative
# to the app root. Each entry is passed to get_weight() in this order.
WEIGHT_FILES = [
    # --- CheXbert ---
    "CheXbert/checkpoint/chexbert.pth",

    # --- CXRGen: Uniformer annotator ---
    "CXRGen/annotator/ckpts/upernet_global_small.pth",

    # --- CXRGen: diffusion model ---
    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",

    # --- BiomedVLP encoders (one copy under CXRGen, one under VG) ---
    "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
    "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",

    # --- CXRGen: lung U-Nets ---
    "CXRGen/LungDetection/models/unet-2v.pt",
    "CXRGen/LungDetection/models/unet-6v.pt",

    # --- DETR checkpoint ---
    "DETR/output/checkpoint.pth",

    # --- VG checkpoints ---
    "VG/weights/checkpoint0399.pth",
    "VG/weights/checkpoint0399_log4.pth",
    "VG/weights/checkpoint_best_regular.pth",
]
def ensure_all_vicca_weights():
    """Make every file listed in ``WEIGHT_FILES`` available locally.

    Calls ``get_weight`` once per entry, in list order, with its default
    repo and base directory, and prints one progress line per file. Entries
    that ``get_weight`` already finds on disk are returned without a
    download, so repeated calls in the same container are cheap.
    """
    for weight in WEIGHT_FILES:
        local_file = get_weight(weight)
        # Flush after each line so progress is visible while later files are
        # still being fetched.
        print(f"[VICCA] Ensured weight: {weight} -> {local_file}", flush=True)