"""Data augmentation and preprocessing pipelines.
Design rationale for each augmentation:
- RandomHorizontalFlip: chest X-rays have bilateral symmetry; flipping is anatomically valid.
- RandomRotation: slight patient positioning variation in real radiographs.
- RandomAffine (translate + shear): simulates positioning shifts and beam angle variation.
- ColorJitter (brightness/contrast): compensates for varying X-ray exposure settings.
- GaussianBlur: simulates varying sharpness due to patient motion or detector resolution.
- RandomErasing (CutOut): forces model to rely on distributed features, not single bright regions;
also simulates radio-opaque artifacts (leads, clips, implants).
- ImageNet normalization: even for grayscale medical images, ImageNet stats are standard when
using ImageNet-pretrained backbones (DenseNet-121). Both models use 3-channel RGB.
- CLAHE (optional): Contrast Limited Adaptive Histogram Equalisation enhances local contrast,
making low-contrast findings (nodules, infiltrations) more visible before the network sees them.
Applied in LAB colour space so brightness is enhanced without shifting colour balance.
"""
from __future__ import annotations
import numpy as np
from PIL import Image
from torchvision import transforms
class CLAHETransform:
    """Apply CLAHE to a PIL image to enhance local contrast.

    Standard preprocessing in radiology AI — boosts visibility of small, low-contrast
    findings (Nodule, Infiltration, Pneumonia) that are otherwise hard to learn from.
    Applied in LAB colour space on the L (lightness) channel only; the two colour
    channels are passed through untouched.

    Args:
        clip_limit: Forwarded to ``cv2.createCLAHE`` as ``clipLimit``.
        tile_grid_size: Forwarded to ``cv2.createCLAHE`` as ``tileGridSize``.
    """

    def __init__(self, clip_limit: float = 2.0, tile_grid_size: tuple[int, int] = (8, 8)) -> None:
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __repr__(self) -> str:
        """Return a constructor-style description including the configured parameters."""
        # Without this, printing a pipeline that contains this step shows only the
        # default "<... object at 0x...>" text, which hides the CLAHE settings.
        return (
            f"{type(self).__name__}("
            f"clip_limit={self.clip_limit!r}, tile_grid_size={self.tile_grid_size!r})"
        )

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return a new RGB image whose lightness channel has been CLAHE-equalised.

        Args:
            img: Input PIL image; it is converted to RGB before processing.

        Returns:
            A new PIL image built from the equalised RGB array.
        """
        import cv2  # lazy import — only required when CLAHE is enabled

        img_np = np.array(img.convert("RGB"))
        lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
        # The CLAHE operator is built on every call, so the instance itself only
        # ever holds the two plain configuration values set in __init__.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        # Channel 0 of the LAB array is lightness; equalise it in place.
        lab[:, :, 0] = clahe.apply(lab[:, :, 0])
        result = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return Image.fromarray(result)
def get_train_transforms(image_size: int = 320, use_clahe: bool = True) -> transforms.Compose:
    """Build the augmentation pipeline applied to training images.

    Args:
        image_size: Side length of the square image produced by the resize step.
        use_clahe: When true, a default ``CLAHETransform`` runs before every other
            step. Recommended for chest X-rays.

    Returns:
        A ``transforms.Compose`` holding the stages in the order they appear below.
    """
    # Optional contrast enhancement always comes first.
    contrast_stage = [CLAHETransform()] if use_clahe else []

    geometric_stage = [
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        # Patient positioning variation: up to 5% translation and 5° shear
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), shear=5),
    ]

    photometric_stage = [
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.05),
        # Occasional blur stands in for focus/motion differences between scanners
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))],
            p=0.2,
        ),
    ]

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    tensor_stage = [
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        # CutOut on the normalised tensor: mimics radio-opaque objects and pushes
        # the model towards distributed features
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.08), ratio=(0.5, 2.0)),
    ]

    return transforms.Compose(
        [*contrast_stage, *geometric_stage, *photometric_stage, *tensor_stage]
    )
def get_eval_transforms(image_size: int = 320, use_clahe: bool = True) -> transforms.Compose:
    """Build the deterministic pipeline used for evaluation and testing.

    No augmentation is applied: the image is optionally CLAHE-enhanced, resized,
    converted to a tensor, and normalised.

    Args:
        image_size: Side length of the square image produced by the resize step.
        use_clahe: When true, a default ``CLAHETransform`` runs before the resize.
            Should match the setting used for training.

    Returns:
        A ``transforms.Compose`` holding the stages in the order they appear below.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Optional contrast enhancement always comes first.
    contrast_stage = [CLAHETransform()] if use_clahe else []
    core_stage = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose([*contrast_stage, *core_stage])
|