File size: 2,564 Bytes
0f8411f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# weights_utils.py
from __future__ import annotations

from pathlib import Path
from typing import Optional
import shutil

from huggingface_hub import hf_hub_download

# Your public weights repo
DEFAULT_WEIGHTS_REPO_ID = "sayehghp/vicca-weights"


def get_weight(
    rel_path: str,
    repo_id: str = DEFAULT_WEIGHTS_REPO_ID,
    base_dir: Optional[Path] = None,
) -> str:
    """
    Ensure the weight file at `rel_path` exists locally and return its path.

    Parameters
    ----------
    rel_path : str
        Path relative to the app root, e.g.
        "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth" or
        "VG/weights/checkpoint0399.pth".
    repo_id : str
        HF Hub repo to download from when the file is missing locally.
    base_dir : Optional[Path]
        Root directory the weights are mirrored under; defaults to this
        file's directory (/home/user/app in Spaces).

    Returns
    -------
    str
        Path to the local copy of the weight file, as a string.
    """
    if base_dir is None:
        # /home/user/app in Spaces
        base_dir = Path(__file__).parent

    local_path = base_dir / rel_path
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # 1) If we've already mirrored it into the repo tree, reuse it.
    if local_path.is_file():
        return str(local_path)

    # 2) Otherwise, fetch it from HF Hub (this uses HF's own cache).
    #    A bare filename has no subfolder; a leading "/" would produce an
    #    empty subfolder string, which we normalize to None below.
    if "/" in rel_path:
        subfolder, filename = rel_path.rsplit("/", 1)
    else:
        subfolder, filename = None, rel_path

    cached_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder or None,
        # IMPORTANT: don't set force_download=True
    )

    # 3) Mirror into our app tree so legacy code keeps working.
    #    Copy to a temp name first, then atomically rename into place, so a
    #    process killed mid-copy never leaves a truncated file that step (1)
    #    would later mistake for a complete checkpoint.
    tmp_path = local_path.with_name(local_path.name + ".part")
    shutil.copy2(cached_path, tmp_path)
    tmp_path.replace(local_path)

    return str(local_path)


# All the heavy VICCA weights you listed
# All the heavy VICCA weights you listed.
# Each entry is a path relative to the app root; get_weight() resolves and
# mirrors it from the HF Hub repo on first use.
WEIGHT_FILES = [
    # CheXbert report-labeler checkpoint
    "CheXbert/checkpoint/chexbert.pth",

    # Uniformer annotator
    "CXRGen/annotator/ckpts/upernet_global_small.pth",

    # CXRGen diffusion model
    "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",

    # BiomedVLP encoders (same model mirrored for CXRGen and VG)
    "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
    "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",

    # Lung U-Nets
    "CXRGen/LungDetection/models/unet-2v.pt",
    "CXRGen/LungDetection/models/unet-6v.pt",

    # DETR checkpoint
    "DETR/output/checkpoint.pth",

    # VG checkpoints
    "VG/weights/checkpoint0399.pth",
    "VG/weights/checkpoint0399_log4.pth",
    "VG/weights/checkpoint_best_regular.pth",
]


def ensure_all_vicca_weights():
    """
    Download all VICCA weights once per container.

    This both:
      * uses HF cache for each file
      * mirrors them into the expected relative paths

    Returns
    -------
    dict[str, str]
        Mapping from each relative weight path in WEIGHT_FILES to its
        resolved local path. Returning the mapping is backward-compatible:
        existing callers that ignore the return value are unaffected.
    """
    resolved: dict[str, str] = {}
    for rel_path in WEIGHT_FILES:
        path = get_weight(rel_path)
        resolved[rel_path] = path
        print(f"[VICCA] Ensured weight: {rel_path} -> {path}", flush=True)
    return resolved