| | |
| | from __future__ import annotations |
| |
|
| | from pathlib import Path |
| | from typing import Optional |
| | import shutil |
| |
|
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Hugging Face Hub repository that get_weight() downloads from when the caller
# does not pass an explicit `repo_id`.
DEFAULT_WEIGHTS_REPO_ID = "sayehghp/vicca-weights"
| |
|
| |
|
def get_weight(
    rel_path: str,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Path] = None,
) -> str:
    """
    Ensure `rel_path` exists locally and return its path as a string.

    rel_path is a path relative to the app root, e.g.
        "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth"
        "VG/weights/checkpoint0399.pth"

    If the file is already present under `base_dir` it is returned as-is.
    Otherwise it is fetched with `hf_hub_download` and copied from the
    returned cache path to `base_dir / rel_path`.

    Args:
        rel_path: "/"-separated location of the file. It is used both as the
            location inside the Hub repo (split into subfolder + filename)
            and as the location below `base_dir`.
        repo_id: Hub repository to download from when the file is missing.
        base_dir: Directory the file is mirrored into. Defaults to the
            directory containing this module. A plain string is accepted.

    Returns:
        `str(base_dir / rel_path)`. The result is only absolute when
        `base_dir` is absolute.

    Raises:
        Nothing is caught here: errors from `hf_hub_download` and from the
        file copy propagate to the caller.
    """
    # Coerce so callers may pass a plain string as well as a Path.
    base_dir = Path(__file__).parent if base_dir is None else Path(base_dir)

    local_path = base_dir / rel_path
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # Fast path: the weight was already mirrored by an earlier call.
    if local_path.is_file():
        return str(local_path)

    # hf_hub_download takes the directory part and the file name separately.
    if "/" in rel_path:
        subfolder, filename = rel_path.rsplit("/", 1)
    else:
        subfolder, filename = None, rel_path

    cached_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder if subfolder else None,
    )

    # Copy to a sibling temp file and rename it into place. Copying straight
    # to `local_path` would leave a truncated file behind if the copy were
    # interrupted, and the `is_file()` fast path above would then keep
    # returning that corrupt weight on every later call.
    partial_path = local_path.with_name(local_path.name + ".partial")
    try:
        shutil.copy2(cached_path, partial_path)
        partial_path.replace(local_path)
    finally:
        # Only true when the copy or rename failed; do not leave debris.
        if partial_path.exists():
            partial_path.unlink()

    return str(local_path)
| |
|
| |
|
| | |
# Every weight file that ensure_all_vicca_weights() fetches through
# get_weight(). Each entry is a "/"-separated path relative to the app root;
# get_weight() uses the same string as the file's location inside the Hub
# repo. Blank lines separate groups of related files.
WEIGHT_FILES = [
    "CheXbert/checkpoint/chexbert.pth",

    "CXRGen/annotator/ckpts/upernet_global_small.pth",

    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",

    # The same file name is listed under two different directories.
    # NOTE(review): presumably the same BiomedVLP-CXR-BERT weights needed in
    # both places — confirm against the repo contents.
    "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
    "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",

    "CXRGen/LungDetection/models/unet-2v.pt",
    "CXRGen/LungDetection/models/unet-6v.pt",

    "DETR/output/checkpoint.pth",

    "VG/weights/checkpoint0399.pth",
    "VG/weights/checkpoint0399_log4.pth",
    "VG/weights/checkpoint_best_regular.pth",
]
| |
|
| |
|
def ensure_all_vicca_weights():
    """
    Make every file listed in WEIGHT_FILES available locally.

    Each entry is passed to get_weight() with its default repo and base
    directory, and one progress line per file is printed (flushed
    immediately) once that file has been ensured. Files are handled in list
    order; an error from get_weight() stops the run at that file.
    """
    # Lazy pairs: each get_weight() call runs just before its own print, so
    # progress lines appear as files complete rather than all at the end.
    ensured = ((rel_path, get_weight(rel_path)) for rel_path in WEIGHT_FILES)
    for rel_path, path in ensured:
        print(f"[VICCA] Ensured weight: {rel_path} -> {path}", flush=True)
| |
|