# day4model/model_training/src/model22_preprocess.py
# branch: cardio-deploy — "Deploy CardioScan inference" 2026-04-23T18:11:52Z (commit 3f984f1)
"""Optional preprocessing from Model 22 (RAD-DINO notebook).
Thorax crop (ChestX-Det PSPNet), CLAHE, and light albumentations on PIL images.
Only imported when ``CFG.preprocessing_profile == \"model22\"`` and backbone is
``rad-dino`` — keeps default runs free of extra dependencies until enabled.
"""
from __future__ import annotations
import json
import os
from typing import Any, Dict, List, Optional
import numpy as np
import torch
import torchxrayvision as xrv
from PIL import Image
try:
import cv2
except ImportError as e: # pragma: no cover
cv2 = None # type: ignore[misc, assignment]
try:
import albumentations as A
except ImportError as e: # pragma: no cover
A = None # type: ignore[misc, assignment]
# Lazily-loaded ChestX-Det PSPNet used for thorax segmentation (see _get_seg_model).
_SEG_MODEL: Optional[torch.nn.Module] = None
# In-memory bbox cache keyed by image path; values are [x0, y0, x1, y1] lists
# (or None when segmentation failed). _BBOX_CACHE_PATH records which JSON file
# the cache was loaded from so a different path triggers a reload.
_BBOX_CACHE: Optional[Dict[str, Any]] = None
_BBOX_CACHE_PATH: Optional[str] = None
def _require_cv2() -> Any:
    """Return the ``cv2`` module, raising a helpful ImportError when it is missing."""
    if cv2 is not None:
        return cv2
    raise ImportError(
        "opencv-python-headless is required for Model 22 preprocessing (CLAHE). "
        "Install with: pip install opencv-python-headless"
    )
def _require_albumentations() -> Any:
    """Return the ``albumentations`` module, raising a helpful ImportError when missing."""
    if A is not None:
        return A
    raise ImportError(
        "albumentations is required for Model 22 medical augmentations. "
        "Install with: pip install albumentations"
    )
def apply_clahe_pil(pil_img: Image.Image) -> Image.Image:
    """Apply CLAHE on the luminance channel (Model 22: clipLimit=2.0, 8×8 tiles)."""
    cv = _require_cv2()
    gray = np.array(pil_img.convert("L"), dtype=np.uint8)
    equalizer = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return Image.fromarray(equalizer.apply(gray))
# Lazily-built albumentations pipeline (see _get_medical_aug).
_medical_aug = None
def _get_medical_aug():
    """Build the Model 22 albumentations pipeline once and return it thereafter."""
    global _medical_aug
    if _medical_aug is not None:
        return _medical_aug
    Alb = _require_albumentations()
    transforms = [
        Alb.GridDistortion(num_steps=5, distort_limit=0.1, p=0.3),
        Alb.ElasticTransform(alpha=30, sigma=5, p=0.3),
        Alb.GaussNoise(var_limit=(5, 20), p=0.3),
        Alb.Sharpen(alpha=(0.1, 0.3), p=0.3),
    ]
    _medical_aug = Alb.Compose(transforms)
    return _medical_aug
def augment_medical_pil(pil_img: Image.Image) -> Image.Image:
    """Run the Model 22 albumentations on a grayscale PIL image (train only)."""
    gray = np.array(pil_img.convert("L"))
    augmented = _get_medical_aug()(image=gray)["image"]
    return Image.fromarray(augmented)
def _get_seg_model(device: torch.device) -> torch.nn.Module:
    """Return the module-level ChestX-Det PSPNet, loading it on first use."""
    global _SEG_MODEL
    if _SEG_MODEL is not None:
        return _SEG_MODEL
    print("Loading ChestX-Det PSPNet for thorax bounding boxes (Model 22)...")
    _SEG_MODEL = xrv.baseline_models.chestx_det.PSPNet().to(device).eval()
    return _SEG_MODEL
def _load_bbox_cache(cache_path: str) -> Dict[str, Any]:
    """Return the bbox cache for ``cache_path``, reading it from disk on first use.

    A cache loaded for a different path is discarded and re-read.
    """
    global _BBOX_CACHE, _BBOX_CACHE_PATH
    if _BBOX_CACHE_PATH == cache_path and _BBOX_CACHE is not None:
        return _BBOX_CACHE
    _BBOX_CACHE_PATH = cache_path
    if not os.path.isfile(cache_path):
        _BBOX_CACHE = {}
        return _BBOX_CACHE
    with open(cache_path, "r", encoding="utf-8") as fh:
        _BBOX_CACHE = json.load(fh)
    return _BBOX_CACHE
def _save_bbox_cache(cache_path: str) -> None:
    """Persist the in-memory bbox cache to ``cache_path`` (no-op when never loaded)."""
    if _BBOX_CACHE is None:
        return
    parent = os.path.dirname(cache_path) or "."
    os.makedirs(parent, exist_ok=True)
    with open(cache_path, "w", encoding="utf-8") as fh:
        json.dump(_BBOX_CACHE, fh)
def _compute_one_bbox(pil_img: Image.Image, device: torch.device, thorax_pad: float) -> List[int]:
    """Segment one chest X-ray and return a padded thorax bbox ``[x0, y0, x1, y1]``.

    Returns the full-image box when nothing is detected or the detected region
    is degenerate (smaller than 16 px on either side).
    """
    W, H = pil_img.size
    arr = np.array(pil_img, dtype=np.float32)
    # Rescale 0..255 pixel values into the range torchxrayvision models expect.
    arr = xrv.datasets.normalize(arr, 255)
    # Add batch and channel dims: (1, 1, H, W).
    tensor = torch.from_numpy(arr[None, None, ...]).float().to(device)
    seg_model = _get_seg_model(device)
    with torch.no_grad():
        out = seg_model(tensor)
    seg = torch.sigmoid(out)[0]
    # Union of output channels 4, 5, 8 thresholded at 0.5 — presumably the
    # lung/heart classes of ChestX-Det; TODO confirm against seg_model.targets.
    mask = (seg[[4, 5, 8]].max(0).values > 0.5).cpu().numpy()
    if not mask.any():
        return [0, 0, W, H]
    ys, xs = np.where(mask)
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    # Map mask coordinates back to the original image size.
    # NOTE(review): assumes the segmentation output is 512×512 (xrv models
    # resize input internally) — confirm.
    sx, sy = W / 512.0, H / 512.0
    x0, x1 = int(round(x0 * sx)), int(round(x1 * sx))
    y0, y1 = int(round(y0 * sy)), int(round(y1 * sy))
    # Pad the box by a fraction of its own width/height, clamped to the image.
    pad_x = int(round(thorax_pad * (x1 - x0)))
    pad_y = int(round(thorax_pad * (y1 - y0)))
    x0 = max(0, x0 - pad_x)
    y0 = max(0, y0 - pad_y)
    x1 = min(W, x1 + pad_x)
    y1 = min(H, y1 + pad_y)
    # Reject implausibly small crops.
    if x1 - x0 < 16 or y1 - y0 < 16:
        return [0, 0, W, H]
    return [x0, y0, x1, y1]
def ensure_thorax_bboxes(
    image_paths: List[str],
    cache_path: str,
    device: str | torch.device,
    thorax_pad: float = 0.05,
    save_every: int = 64,
) -> None:
    """Populate the JSON cache of thorax bboxes (absolute paths as keys).

    Args:
        image_paths: Image file paths; used verbatim as cache keys.
        cache_path: JSON file mapping path -> [x0, y0, x1, y1] (or None on failure).
        device: Training device; segmentation runs on CUDA when training does,
            otherwise on CPU.
        thorax_pad: Fractional padding applied around the detected thorax box.
        save_every: Flush the cache to disk every N newly segmented images.
    """
    if not image_paths:
        return
    dev = torch.device(device) if isinstance(device, str) else device
    # PSPNet is small; CPU avoids competing with training on MPS/CUDA.
    seg_dev = torch.device("cuda") if dev.type == "cuda" else torch.device("cpu")
    cache = _load_bbox_cache(cache_path)
    # Dedupe while preserving order so repeated paths are segmented only once.
    todo = [p for p in dict.fromkeys(image_paths) if p not in cache]
    if not todo:
        print(f"Thorax bbox cache up to date ({len(cache)} entries): {cache_path}")
        return
    print(f"Segmenting {len(todo)} image(s) for thorax crop (cache: {len(cache)})...")
    for i, p in enumerate(todo, 1):
        try:
            # Context manager releases the file handle promptly; a bare
            # Image.open keeps it open until the pixel data is loaded/GC'd.
            with Image.open(p) as img:
                cache[p] = _compute_one_bbox(img.convert("L"), seg_dev, thorax_pad)
        except Exception as e:  # noqa: BLE001
            print(f" bbox failed for {os.path.basename(p)}: {e!r} -> full image")
            cache[p] = None  # sentinel: crop_thorax_pil falls back to the full image
        if i % save_every == 0:
            _save_bbox_cache(cache_path)
            print(f" flushed cache {i}/{len(todo)}")
    global _BBOX_CACHE
    _BBOX_CACHE = cache
    _save_bbox_cache(cache_path)
    print(f"Thorax bbox cache saved ({len(cache)} paths) → {cache_path}")
def crop_thorax_pil(pil_img: Image.Image, image_path: str, cache_path: str) -> Image.Image:
    """Crop a PIL image with its cached bbox; return it unchanged when the
    cache has no usable entry for ``image_path``."""
    entry = _load_bbox_cache(cache_path).get(image_path)
    # None (failed segmentation) and malformed entries both fail this check.
    if not isinstance(entry, (list, tuple)) or len(entry) != 4:
        return pil_img
    x0, y0, x1, y1 = (int(v) for v in entry)
    width, height = pil_img.size
    # Reject degenerate or fully out-of-bounds boxes.
    if x1 <= x0 or y1 <= y0 or x0 >= width or y0 >= height:
        return pil_img
    return pil_img.crop((x0, y0, x1, y1))