File size: 6,429 Bytes
3f984f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Optional preprocessing from Model 22 (RAD-DINO notebook).

Thorax crop (ChestX-Det PSPNet), CLAHE, and light albumentations on PIL images.
Only imported when ``CFG.preprocessing_profile == \"model22\"`` and backbone is
``rad-dino`` — keeps default runs free of extra dependencies until enabled.
"""

from __future__ import annotations

import json
import os
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torchxrayvision as xrv
from PIL import Image

# Optional heavy dependencies: imported lazily-ish here, but tolerated as
# missing until a feature that needs them is actually used (see _require_*).
try:
    import cv2
except ImportError:  # pragma: no cover
    cv2 = None  # type: ignore[misc, assignment]

try:
    import albumentations as A
except ImportError:  # pragma: no cover
    A = None  # type: ignore[misc, assignment]

# Module-level lazy state, populated on first use.
_SEG_MODEL: Optional[torch.nn.Module] = None  # ChestX-Det PSPNet singleton
_BBOX_CACHE: Optional[Dict[str, Any]] = None  # image path -> [x0, y0, x1, y1] or None
_BBOX_CACHE_PATH: Optional[str] = None  # file the dict above was loaded from


def _require_cv2() -> Any:
    """Return the ``cv2`` module, or raise with install instructions."""
    if cv2 is not None:
        return cv2
    raise ImportError(
        "opencv-python-headless is required for Model 22 preprocessing (CLAHE). "
        "Install with: pip install opencv-python-headless"
    )


def _require_albumentations() -> Any:
    """Return the ``albumentations`` module, or raise with install instructions."""
    if A is not None:
        return A
    raise ImportError(
        "albumentations is required for Model 22 medical augmentations. "
        "Install with: pip install albumentations"
    )


def apply_clahe_pil(pil_img: Image.Image) -> Image.Image:
    """CLAHE on luminance (Model 22: clipLimit=2.0, 8×8 tiles)."""
    cv = _require_cv2()
    gray = np.asarray(pil_img.convert("L"), dtype=np.uint8)
    equalizer = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return Image.fromarray(equalizer.apply(gray))


_medical_aug = None  # cached Compose, built on first call to _get_medical_aug


def _get_medical_aug():
    """Build (once) and return the Model 22 albumentations pipeline."""
    global _medical_aug
    if _medical_aug is not None:
        return _medical_aug
    Alb = _require_albumentations()
    transforms = [
        Alb.GridDistortion(num_steps=5, distort_limit=0.1, p=0.3),
        Alb.ElasticTransform(alpha=30, sigma=5, p=0.3),
        Alb.GaussNoise(var_limit=(5, 20), p=0.3),
        Alb.Sharpen(alpha=(0.1, 0.3), p=0.3),
    ]
    _medical_aug = Alb.Compose(transforms)
    return _medical_aug


def augment_medical_pil(pil_img: Image.Image) -> Image.Image:
    """Albumentations on grayscale PIL (train only)."""
    gray = np.array(pil_img.convert("L"))
    augmented = _get_medical_aug()(image=gray)
    return Image.fromarray(augmented["image"])


def _get_seg_model(device: torch.device) -> torch.nn.Module:
    """Return the process-wide ChestX-Det PSPNet, loading it on first use."""
    global _SEG_MODEL
    if _SEG_MODEL is not None:
        return _SEG_MODEL
    print("Loading ChestX-Det PSPNet for thorax bounding boxes (Model 22)...")
    _SEG_MODEL = xrv.baseline_models.chestx_det.PSPNet().to(device).eval()
    return _SEG_MODEL


def _load_bbox_cache(cache_path: str) -> Dict[str, Any]:
    """Load (and memoize) the thorax-bbox JSON cache for *cache_path*.

    Keys are image paths; values are ``[x0, y0, x1, y1]`` lists or ``None``
    for images whose segmentation failed. The file is re-read only when the
    requested path differs from the one currently memoized.
    """
    global _BBOX_CACHE, _BBOX_CACHE_PATH
    if _BBOX_CACHE is not None and _BBOX_CACHE_PATH == cache_path:
        return _BBOX_CACHE
    if os.path.isfile(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    else:
        loaded = {}
    # Commit both globals only after a successful load: previously the path
    # was updated before json.load, so a corrupt file left the module
    # claiming the OLD dict belonged to the NEW path on every later call.
    _BBOX_CACHE = loaded
    _BBOX_CACHE_PATH = cache_path
    return _BBOX_CACHE


def _save_bbox_cache(cache_path: str) -> None:
    """Write the in-memory bbox cache to *cache_path* (no-op before any load)."""
    cache = _BBOX_CACHE
    if cache is None:
        return
    target_dir = os.path.dirname(cache_path) or "."
    os.makedirs(target_dir, exist_ok=True)
    with open(cache_path, "w", encoding="utf-8") as fh:
        json.dump(cache, fh)


def _compute_one_bbox(pil_img: Image.Image, device: torch.device, thorax_pad: float) -> List[int]:
    """Segment one grayscale chest X-ray and return a padded thorax bbox.

    Returns ``[x0, y0, x1, y1]`` in original-image pixel coordinates, or the
    full-image box when the mask is empty or the box degenerates.
    """
    W, H = pil_img.size
    arr = np.array(pil_img, dtype=np.float32)
    # Map 8-bit pixels into torchxrayvision's expected input range.
    arr = xrv.datasets.normalize(arr, 255)
    tensor = torch.from_numpy(arr[None, None, ...]).float().to(device)
    seg_model = _get_seg_model(device)
    with torch.no_grad():
        out = seg_model(tensor)
    seg = torch.sigmoid(out)[0]
    # Channels 4/5/8 — presumably the lung and heart classes in the
    # ChestX-Det label order; TODO confirm against the xrv targets list.
    mask = (seg[[4, 5, 8]].max(0).values > 0.5).cpu().numpy()
    if not mask.any():
        return [0, 0, W, H]
    ys, xs = np.where(mask)
    # Tight box around the thresholded mask (exclusive upper edges).
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    # NOTE(review): assumes the mask lives on a 512x512 grid regardless of
    # input size (i.e. the PSPNet resizes internally) — confirm; rescale to
    # original resolution.
    sx, sy = W / 512.0, H / 512.0
    x0, x1 = int(round(x0 * sx)), int(round(x1 * sx))
    y0, y1 = int(round(y0 * sy)), int(round(y1 * sy))
    # Symmetric margin proportional to the tight box, clamped to the image.
    pad_x = int(round(thorax_pad * (x1 - x0)))
    pad_y = int(round(thorax_pad * (y1 - y0)))
    x0 = max(0, x0 - pad_x)
    y0 = max(0, y0 - pad_y)
    x1 = min(W, x1 + pad_x)
    y1 = min(H, y1 + pad_y)
    # Implausibly small boxes (< 16 px a side) fall back to the full image.
    if x1 - x0 < 16 or y1 - y0 < 16:
        return [0, 0, W, H]
    return [x0, y0, x1, y1]


def ensure_thorax_bboxes(
    image_paths: List[str],
    cache_path: str,
    device: str | torch.device,
    thorax_pad: float = 0.05,
    save_every: int = 64,
) -> None:
    """Populate JSON cache of thorax bboxes (absolute paths as keys)."""
    if not image_paths:
        return
    if isinstance(device, str):
        dev = torch.device(device)
    else:
        dev = device
    # PSPNet is small; CPU avoids competing with training on MPS/CUDA.
    seg_dev = torch.device("cuda" if dev.type == "cuda" else "cpu")
    cache = _load_bbox_cache(cache_path)
    todo = [path for path in image_paths if path not in cache]
    if not todo:
        print(f"Thorax bbox cache up to date ({len(cache)} entries): {cache_path}")
        return
    print(f"Segmenting {len(todo)} image(s) for thorax crop (cache: {len(cache)})...")
    for i, p in enumerate(todo, 1):
        try:
            gray = Image.open(p).convert("L")
            cache[p] = _compute_one_bbox(gray, seg_dev, thorax_pad)
        except Exception as e:  # noqa: BLE001
            # Best-effort: a failed image gets None, which means "no crop".
            print(f"  bbox failed for {os.path.basename(p)}: {e!r} -> full image")
            cache[p] = None
        # Periodic flush so long runs survive interruption.
        if i % save_every == 0:
            _save_bbox_cache(cache_path)
            print(f"  flushed cache  {i}/{len(todo)}")
    global _BBOX_CACHE
    _BBOX_CACHE = cache
    _save_bbox_cache(cache_path)
    print(f"Thorax bbox cache saved ({len(cache)} paths) → {cache_path}")


def crop_thorax_pil(pil_img: Image.Image, image_path: str, cache_path: str) -> Image.Image:
    """Crop PIL using cached bbox; full image if missing or invalid.

    The cached box is clamped to the current image bounds: PIL's ``crop``
    zero-pads regions outside the image, so a stale or corrupt cache entry
    (negative origin, or extent past W/H) would otherwise inject black
    borders into training data.
    """
    cache = _load_bbox_cache(cache_path)
    bbox = cache.get(image_path)
    # None (failed segmentation), missing keys, and malformed entries all
    # fail the isinstance/len check and fall back to the full image.
    if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
        return pil_img
    W, H = pil_img.size
    x0 = max(0, int(bbox[0]))
    y0 = max(0, int(bbox[1]))
    x1 = min(W, int(bbox[2]))
    y1 = min(H, int(bbox[3]))
    if x1 <= x0 or y1 <= y0:
        return pil_img
    return pil_img.crop((x0, y0, x1, y1))