| from __future__ import annotations |
|
|
| import os |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
| import torchvision.transforms as T |
|
|
| import torchxrayvision as xrv |
|
|
|
|
| |
| |
| |
| def xrv_normalize_np(pil_img: Image.Image) -> torch.Tensor: |
| """PIL grayscale → (1, H, W) float tensor in [-1024, 1024] (torchxrayvision).""" |
| arr = np.array(pil_img, dtype=np.float32) |
| arr = xrv.datasets.normalize(arr, 255) |
| arr = arr[None, ...] |
| return torch.from_numpy(arr).float() |
|
|
|
|
| _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1) |
| _IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1) |
|
|
|
|
| def imagenet_normalize_np(pil_img: Image.Image) -> torch.Tensor: |
| """PIL grayscale → (3, H, W) float tensor normalized with ImageNet stats. |
| |
| The single grayscale channel is replicated to 3 channels so that ImageNet- |
| pretrained backbones (MobileNet, EfficientNet) receive the expected input shape. |
| """ |
| arr = np.array(pil_img, dtype=np.float32) / 255.0 |
| arr = np.stack([arr, arr, arr], axis=0) |
| arr = (arr - _IMAGENET_MEAN) / _IMAGENET_STD |
| return torch.from_numpy(arr).float() |
|
|
|
|
| |
| _RAD_DINO_PROCESSOR = None |
|
|
|
|
| def _get_rad_dino_processor(): |
| global _RAD_DINO_PROCESSOR |
| if _RAD_DINO_PROCESSOR is None: |
| from transformers import AutoImageProcessor |
| _RAD_DINO_PROCESSOR = AutoImageProcessor.from_pretrained("microsoft/rad-dino") |
| return _RAD_DINO_PROCESSOR |
|
|
|
|
| def rad_dino_normalize(pil_img: Image.Image) -> torch.Tensor: |
| """PIL image → (3, H, W) tensor using RAD-DINO's official AutoImageProcessor. |
| |
| Applies the exact same MIMIC-CXR normalization stats used during rad-dino |
| pretraining (mean/std provided by the HuggingFace processor config). |
| Grayscale images are converted to RGB by replicating the single channel |
| before passing to the processor. |
| """ |
| if pil_img.mode != "RGB": |
| pil_img = pil_img.convert("RGB") |
| proc = _get_rad_dino_processor() |
| out = proc(images=pil_img, return_tensors="pt") |
| return out["pixel_values"][0] |
|
|
|
|
| def get_normalize_fn(backbone: str): |
| """Return the correct normalization callable for the given backbone name. |
| |
| "densenet121" / "densenet121-res224-all" |
| → xrv_normalize_np (1-ch grayscale, [-1024, 1024]) |
| "rad-dino" |
| → rad_dino_normalize (3-ch RGB via AutoImageProcessor, MIMIC-CXR stats) |
| RAD-DINO is a ViT-B/14; feed at 518×518 for best accuracy. |
| all other torchvision backbones |
| → imagenet_normalize_np (3-ch RGB replicated, ImageNet stats) |
| """ |
| if backbone in ("densenet121", "densenet121-res224-all"): |
| return xrv_normalize_np |
| if backbone == "rad-dino": |
| return rad_dino_normalize |
| return imagenet_normalize_np |
|
|
|
|
| |
| |
| |
| class ChestXrayDataset(Dataset): |
| """Returns (image_tensor, label, filename) triples. |
| |
| backbone controls the normalization applied after PIL transforms: |
| "densenet121" → single-channel tensor in [-1024, 1024] (xrv) |
| any torchvision model → 3-channel tensor with ImageNet normalization |
| """ |
|
|
| def __init__( |
| self, |
| df: pd.DataFrame, |
| pil_transform=None, |
| use_erasing: bool = False, |
| backbone: str | None = None, |
| ) -> None: |
| from src.config import CFG |
| self.df = df.reset_index(drop=True) |
| self.pil_transform = pil_transform |
| self.use_erasing = use_erasing |
| self._normalize = get_normalize_fn(backbone or CFG.backbone) |
| self._erasing = T.RandomErasing( |
| p=0.5, scale=(0.02, 0.08), ratio=(0.3, 3.3), value=0 |
| ) |
|
|
| def __len__(self) -> int: |
| return len(self.df) |
|
|
| def __getitem__(self, idx: int): |
| row = self.df.iloc[idx] |
| img = Image.open(row["image_path"]).convert("L") |
| if self.pil_transform is not None: |
| img = self.pil_transform(img) |
| normalize = getattr(self, "_normalize", xrv_normalize_np) |
| tensor = normalize(img) |
| if self.use_erasing: |
| tensor = self._erasing(tensor) |
| label = torch.tensor(float(row["label"]), dtype=torch.float32) |
| return tensor, label, row["filename"] |
|
|
|
|
| |
| |
| |
| class TTADataset(Dataset): |
| """Used by inference passes (one TTA transform per pass).""" |
|
|
| def __init__( |
| self, |
| df: pd.DataFrame, |
| pil_transform, |
| image_dir: Optional[str] = None, |
| backbone: str | None = None, |
| ) -> None: |
| from src.config import CFG |
| self.df = df.reset_index(drop=True) |
| self.pil_transform = pil_transform |
| self.image_dir = image_dir |
| self._normalize = get_normalize_fn(backbone or CFG.backbone) |
|
|
| def __len__(self) -> int: |
| return len(self.df) |
|
|
| def __getitem__(self, idx: int): |
| row = self.df.iloc[idx] |
| if "image_path" in row and pd.notna(row.get("image_path")): |
| path = row["image_path"] |
| else: |
| path = os.path.join(self.image_dir, row["filename"]) |
| img = Image.open(path).convert("L") |
| img = self.pil_transform(img) |
| normalize = getattr(self, "_normalize", xrv_normalize_np) |
| tensor = normalize(img) |
| label = float(row["label"]) if "label" in row and not pd.isna(row.get("label", np.nan)) else 0.0 |
| name = row["filename"] if "filename" in row else os.path.basename(path) |
| return tensor, torch.tensor(label, dtype=torch.float32), name |
|
|
|
|
| |
| |
| |
| class SubmissionDataset(Dataset): |
| """Unlabelled test images for final inference. |
| |
| Returns (image_tensor, filename). |
| """ |
|
|
| def __init__( |
| self, |
| image_dir: str, |
| pil_transform=None, |
| backbone: str | None = None, |
| ) -> None: |
| from src.config import CFG |
| self.image_dir = image_dir |
| self.pil_transform = pil_transform |
| self._normalize = get_normalize_fn(backbone or CFG.backbone) |
| self.image_files = sorted( |
| f for f in os.listdir(image_dir) |
| if f.lower().endswith((".png", ".jpg", ".jpeg")) |
| ) |
|
|
| def __len__(self) -> int: |
| return len(self.image_files) |
|
|
| def __getitem__(self, idx: int): |
| fname = self.image_files[idx] |
| img = Image.open(os.path.join(self.image_dir, fname)).convert("L") |
| if self.pil_transform is not None: |
| img = self.pil_transform(img) |
| normalize = getattr(self, "_normalize", xrv_normalize_np) |
| tensor = normalize(img) |
| return tensor, fname |
|
|