"""Optional preprocessing from Model 22 (RAD-DINO notebook).
Thorax crop (ChestX-Det PSPNet), CLAHE, and light albumentations on PIL images.
Only imported when ``CFG.preprocessing_profile == \"model22\"`` and backbone is
``rad-dino`` — keeps default runs free of extra dependencies until enabled.
"""
from __future__ import annotations
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torchxrayvision as xrv
from PIL import Image
try:
import cv2
except ImportError as e: # pragma: no cover
cv2 = None # type: ignore[misc, assignment]
try:
import albumentations as A
except ImportError as e: # pragma: no cover
A = None # type: ignore[misc, assignment]
# Lazily-initialized module-level singletons (populated on first use).
_SEG_MODEL: Optional[torch.nn.Module] = None  # cached ChestX-Det PSPNet instance
_BBOX_CACHE: Optional[Dict[str, Any]] = None  # image path -> bbox list (or None on failure)
_BBOX_CACHE_PATH: Optional[str] = None  # path the in-memory cache was loaded from
def _require_cv2() -> Any:
    """Return the ``cv2`` module, or raise with install instructions if absent."""
    if cv2 is not None:
        return cv2
    raise ImportError(
        "opencv-python-headless is required for Model 22 preprocessing (CLAHE). "
        "Install with: pip install opencv-python-headless"
    )
def _require_albumentations() -> Any:
    """Return the ``albumentations`` module, or raise with install instructions."""
    if A is not None:
        return A
    raise ImportError(
        "albumentations is required for Model 22 medical augmentations. "
        "Install with: pip install albumentations"
    )
def apply_clahe_pil(pil_img: Image.Image) -> Image.Image:
    """Contrast-limited adaptive histogram equalization on the grayscale image.

    Uses the Model 22 settings: clipLimit=2.0 with an 8x8 tile grid.
    """
    cv = _require_cv2()
    equalizer = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    gray = np.array(pil_img.convert("L"), dtype=np.uint8)
    return Image.fromarray(equalizer.apply(gray))
# Lazily-built albumentations pipeline; constructed on first use.
_medical_aug = None


def _get_medical_aug():
    """Build (once) and return the Model 22 albumentations pipeline."""
    global _medical_aug
    if _medical_aug is not None:
        return _medical_aug
    Alb = _require_albumentations()
    # Each transform fires independently with probability 0.3.
    _medical_aug = Alb.Compose(
        [
            Alb.GridDistortion(num_steps=5, distort_limit=0.1, p=0.3),
            Alb.ElasticTransform(alpha=30, sigma=5, p=0.3),
            Alb.GaussNoise(var_limit=(5, 20), p=0.3),
            Alb.Sharpen(alpha=(0.1, 0.3), p=0.3),
        ]
    )
    return _medical_aug
def augment_medical_pil(pil_img: Image.Image) -> Image.Image:
    """Run the albumentations pipeline on a grayscale PIL image (train only)."""
    pipeline = _get_medical_aug()
    gray = np.array(pil_img.convert("L"))
    augmented = pipeline(image=gray)["image"]
    return Image.fromarray(augmented)
def _get_seg_model(device: torch.device) -> torch.nn.Module:
    """Lazily construct the ChestX-Det PSPNet segmenter in eval mode (cached)."""
    global _SEG_MODEL
    if _SEG_MODEL is not None:
        return _SEG_MODEL
    print("Loading ChestX-Det PSPNet for thorax bounding boxes (Model 22)...")
    _SEG_MODEL = xrv.baseline_models.chestx_det.PSPNet().to(device).eval()
    return _SEG_MODEL
def _load_bbox_cache(cache_path: str) -> Dict[str, Any]:
    """Load (and memoize) the thorax-bbox JSON cache stored at *cache_path*.

    Re-reads from disk only when the requested path differs from the one
    currently memoized; a missing file yields an empty dict.
    """
    global _BBOX_CACHE, _BBOX_CACHE_PATH
    if _BBOX_CACHE_PATH == cache_path and _BBOX_CACHE is not None:
        return _BBOX_CACHE
    _BBOX_CACHE_PATH = cache_path
    if not os.path.isfile(cache_path):
        _BBOX_CACHE = {}
    else:
        with open(cache_path, "r", encoding="utf-8") as f:
            _BBOX_CACHE = json.load(f)
    return _BBOX_CACHE
def _save_bbox_cache(cache_path: str) -> None:
    """Persist the in-memory bbox cache as JSON; no-op when nothing is loaded."""
    if _BBOX_CACHE is None:
        return
    parent = os.path.dirname(cache_path) or "."
    os.makedirs(parent, exist_ok=True)
    with open(cache_path, "w", encoding="utf-8") as f:
        json.dump(_BBOX_CACHE, f)
def _compute_one_bbox(pil_img: Image.Image, device: torch.device, thorax_pad: float) -> List[int]:
    """Segment one grayscale chest X-ray and return a padded thorax bbox.

    Returns ``[x0, y0, x1, y1]`` in original-image pixel coordinates, or the
    full-image box ``[0, 0, W, H]`` when segmentation finds nothing usable.
    """
    W, H = pil_img.size
    arr = np.array(pil_img, dtype=np.float32)
    # Map 0..255 pixel values into the value range torchxrayvision expects.
    arr = xrv.datasets.normalize(arr, 255)
    tensor = torch.from_numpy(arr[None, None, ...]).float().to(device)
    seg_model = _get_seg_model(device)
    with torch.no_grad():
        out = seg_model(tensor)
    seg = torch.sigmoid(out)[0]
    # Union of channels 4, 5, 8 thresholded at 0.5 — presumably the lung/heart
    # classes that bound the thorax in ChestX-Det's class order; TODO confirm
    # against the torchxrayvision PSPNet target list.
    mask = (seg[[4, 5, 8]].max(0).values > 0.5).cpu().numpy()
    if not mask.any():
        return [0, 0, W, H]
    ys, xs = np.where(mask)
    # Tight box in mask coordinates (exclusive upper edge).
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    # Rescale from the model's working resolution back to image pixels —
    # assumes the mask is 512x512 (xrv's internal resize); verify if the
    # model or input resolution ever changes.
    sx, sy = W / 512.0, H / 512.0
    x0, x1 = int(round(x0 * sx)), int(round(x1 * sx))
    y0, y1 = int(round(y0 * sy)), int(round(y1 * sy))
    # Expand the box by thorax_pad (fraction of width/height) on each side,
    # clamped to the image bounds.
    pad_x = int(round(thorax_pad * (x1 - x0)))
    pad_y = int(round(thorax_pad * (y1 - y0)))
    x0 = max(0, x0 - pad_x)
    y0 = max(0, y0 - pad_y)
    x1 = min(W, x1 + pad_x)
    y1 = min(H, y1 + pad_y)
    # A degenerate (under 16px) box is treated as a segmentation failure.
    if x1 - x0 < 16 or y1 - y0 < 16:
        return [0, 0, W, H]
    return [x0, y0, x1, y1]
def ensure_thorax_bboxes(
    image_paths: List[str],
    cache_path: str,
    device: str | torch.device,
    thorax_pad: float = 0.05,
    save_every: int = 64,
) -> None:
    """Populate JSON cache of thorax bboxes (absolute paths as keys).

    Paths already present in the cache are skipped — including failed entries,
    which are stored as ``None`` so they are not retried every run. The cache
    is flushed to disk every *save_every* new entries and once at the end.
    """
    if not image_paths:
        return
    dev = torch.device(device) if isinstance(device, str) else device
    # PSPNet is small; CPU avoids competing with training on MPS/CUDA.
    seg_dev = torch.device("cuda") if dev.type == "cuda" else torch.device("cpu")
    cache = _load_bbox_cache(cache_path)
    todo = [p for p in image_paths if p not in cache]
    if not todo:
        print(f"Thorax bbox cache up to date ({len(cache)} entries): {cache_path}")
        return
    print(f"Segmenting {len(todo)} image(s) for thorax crop (cache: {len(cache)})...")
    for i, p in enumerate(todo, 1):
        try:
            img = Image.open(p).convert("L")
            cache[p] = _compute_one_bbox(img, seg_dev, thorax_pad)
        except Exception as e:  # noqa: BLE001
            # Best-effort: record the failure so the path isn't retried, and
            # let crop_thorax_pil fall back to the full image.
            print(f"  bbox failed for {os.path.basename(p)}: {e!r} -> full image")
            cache[p] = None
        if i % save_every == 0:
            _save_bbox_cache(cache_path)
            print(f"  flushed cache {i}/{len(todo)}")
    # Rebind the module-level cache so later lookups see the new entries.
    global _BBOX_CACHE
    _BBOX_CACHE = cache
    _save_bbox_cache(cache_path)
    print(f"Thorax bbox cache saved ({len(cache)} paths) → {cache_path}")
def crop_thorax_pil(pil_img: Image.Image, image_path: str, cache_path: str) -> Image.Image:
    """Crop PIL image using its cached thorax bbox.

    Returns the full image when the cache has no entry for *image_path*, the
    entry is the ``None`` failure marker, or the box is malformed/degenerate.
    Coordinates are clamped to the image bounds: a stale cache entry recorded
    for a different resolution would otherwise make ``Image.crop`` silently
    pad the out-of-bounds region with black pixels.
    """
    cache = _load_bbox_cache(cache_path)
    bbox = cache.get(image_path)
    # isinstance() also rejects the None failure marker stored on seg errors.
    if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
        return pil_img
    W, H = pil_img.size
    x0 = max(0, int(bbox[0]))
    y0 = max(0, int(bbox[1]))
    x1 = min(W, int(bbox[2]))
    y1 = min(H, int(bbox[3]))
    if x1 <= x0 or y1 <= y0:
        return pil_img
    return pil_img.crop((x0, y0, x1, y1))