# vicca / weights_hub.py
# Author: sayehghp — commit 0f8411f ("Add application file").
from huggingface_hub import hf_hub_download
REPO_ID = "sayehghp/vicca-weights"

# Manifest of every VICCA checkpoint stored in the hub repo, as
# (result key, subfolder inside the repo, filename) triples. Keeping this as
# data (instead of one hand-written download call per file) makes the
# key -> repo-path mapping auditable in one place and lets it be inspected
# without touching the network.
VICCA_WEIGHT_FILES = (
    # 1) CheXbert checkpoint
    ("chexbert", "CheXbert/checkpoint", "chexbert.pth"),
    # 2) CXRGen annotator (UPerNet)
    (
        "upernet_global_small",
        "CXRGen/annotator/ckpts",
        "upernet_global_small.pth",
    ),
    # 3) CXRGen main diffusion checkpoint
    ("cxrgen_main", "CXRGen/checkpoints", "cn_d25ofd18_epoch-v18.pth"),
    # 4) BiomedVLP CXR BERT used by CXRGen
    (
        "biomedvlp_cxr_bert_cxrgen",
        "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT",
        "pytorch_model.bin",
    ),
    # 5) Lung detection U-Net models
    ("unet_2v", "CXRGen/LungDetection/models", "unet-2v.pt"),
    ("unet_6v", "CXRGen/LungDetection/models", "unet-6v.pt"),
    # 6) DETR checkpoint (for shift computation)
    ("detr_checkpoint", "DETR/output", "checkpoint.pth"),
    # 7) BiomedVLP CXR BERT used by VG
    (
        "biomedvlp_cxr_bert_vg",
        "VG/weights/BiomedVLP-CXR-BERT",
        "pytorch_model.bin",
    ),
    # 8) Visual grounding / groundingDINO checkpoints
    ("vg_checkpoint0399", "VG/weights", "checkpoint0399.pth"),
    ("vg_checkpoint0399_log4", "VG/weights", "checkpoint0399_log4.pth"),
    (
        "vg_checkpoint_best_regular",
        "VG/weights",
        "checkpoint_best_regular.pth",
    ),
    ("groundingdino_swint_ogc", "VG/weights", "groundingdino_swint_ogc.pth"),
)


def download_vicca_weights(
    cache_dir: str = "weights_cache",
    *,
    repo_id: str = REPO_ID,
    **hub_kwargs,
):
    """Download all VICCA-related weights from the HF hub and return their paths.

    Each entry of ``VICCA_WEIGHT_FILES`` is fetched with
    ``huggingface_hub.hf_hub_download``, one call per file, in manifest order.

    Args:
        cache_dir: Cache directory handed to every ``hf_hub_download`` call.
        repo_id: Hub repository to download from. Defaults to ``REPO_ID``;
            override it to use a fork or mirror with the same file layout.
        **hub_kwargs: Extra keyword arguments forwarded unchanged to every
            ``hf_hub_download`` call (for example ``revision`` to pin a
            commit, ``token`` for a private repo, or ``local_files_only``).
            Passing ``repo_id``, ``filename``, ``subfolder`` or ``cache_dir``
            here raises ``TypeError``, because this function sets them itself.

    Returns:
        A dict mapping each manifest key (e.g. ``"chexbert"``,
        ``"cxrgen_main"``) to the local path returned by ``hf_hub_download``
        for that file. Keys are inserted in manifest order.

    Raises:
        Whatever ``hf_hub_download`` raises (missing repo or file, network
        failure, ...) propagates unchanged. Downloading stops at the first
        failing file.
    """
    return {
        key: hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            cache_dir=cache_dir,
            **hub_kwargs,
        )
        for key, subfolder, filename in VICCA_WEIGHT_FILES
    }