import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2


class LungPreprocessor:
    """Advanced preprocessing pipeline for chest X-rays (Phase 4 Arch).

    ``preprocess`` runs, in order: border crop + soft lung-field masking and
    resize to ``image_size`` x ``image_size``, percentile intensity
    normalization, then CLAHE contrast enhancement.
    """

    def __init__(self, image_size=224):
        """
        Args:
            image_size: Side length (pixels) of the square output image.
        """
        self.image_size = image_size

    def remove_artifacts_and_segment(self, image):
        """Crop the borders and softly mask the lung field (Otsu + morphology).

        Args:
            image: uint8 image, either single-channel ``(H, W)`` /
                ``(H, W, 1)`` or colour ``(H, W, 3)`` (treated as RGB) /
                ``(H, W, 4)`` (treated as RGBA).

        Returns:
            uint8 array of shape ``(image_size, image_size)``. Pixels inside
            the (blurred) mask keep their intensity; pixels outside it keep
            20% of their intensity as context.
        """
        # Ensure a 2-D grayscale image.
        if image.ndim == 3 and image.shape[2] == 1:
            gray = image[:, :, 0].copy()
        elif image.ndim == 3 and image.shape[2] == 4:
            gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        elif image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image.copy()

        # 1. Edge cropping: drop a 3% frame on every side (intended to remove
        #    border/marker artifact zones).
        h, w = gray.shape
        border = int(min(h, w) * 0.03)
        cropped = gray[border:h - border, border:w - border]

        # 2. Basic lung segmentation (thresholding + morphology).
        # Slight blur so Otsu is less sensitive to pixel noise.
        blur = cv2.GaussianBlur(cropped, (5, 5), 0)

        # Otsu picks a global threshold; THRESH_BINARY maps pixels brighter
        # than it to 255 and darker ones to 0. Assuming lungs are darker than
        # the surrounding tissue, they end up 0 here.
        _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Invert so the dark (presumed lung) regions become 255.
        # NOTE(review): every region darker than the Otsu threshold is kept,
        # not only lung tissue — confirm that is acceptable for this data.
        mask = cv2.bitwise_not(thresh)

        # Opening removes small isolated specks (e.g. text/markers).
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)

        # Closing fills small holes inside the masked regions.
        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 30))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel_close, iterations=2)

        # Blur the mask so the blend below has soft edges instead of a hard cut.
        mask = cv2.GaussianBlur(mask, (21, 21), 0)

        # Soft blend: full intensity inside the mask, 20% context outside.
        mask_float = mask.astype(float) / 255.0
        segmented = cropped * mask_float + cropped * 0.2 * (1 - mask_float)
        segmented = segmented.astype(np.uint8)

        # Standardize the size before intensity normalization / CLAHE.
        return cv2.resize(segmented, (self.image_size, self.image_size))

    def apply_clahe(self, image):
        """Apply CLAHE (clip limit 2.0, 8x8 tiles) for local contrast enhancement.

        Args:
            image: Single-channel uint8 image.

        Returns:
            Contrast-enhanced uint8 image of the same shape.
        """
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(image)

    def normalize_intensity(self, image):
        """Clip to the 2nd-98th percentile range and stretch linearly to 0-255.

        A constant image (zero range after clipping) maps to all zeros; the
        1e-8 floor avoids a division by zero.

        Args:
            image: Numeric image array.

        Returns:
            uint8 array of the same shape.
        """
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        lo = image.min()
        span = max(1e-8, image.max() - lo)
        return ((image - lo) / span * 255).astype(np.uint8)

    def preprocess(self, image_path):
        """Run the full pipeline on an image file.

        Args:
            image_path: Path (str or path-like) of the image; loaded as grayscale.

        Returns:
            uint8 array of shape ``(image_size, image_size)``.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")

        image = self.remove_artifacts_and_segment(image)
        image = self.normalize_intensity(image)
        image = self.apply_clahe(image)
        return image


def get_train_transforms(image_size=224):
    """Training augmentations - Advanced Medical Safe (Phase 5 Arch).

    RandomResizedCrop is applied with p=0.8, so 20% of the time the input size
    passes through unchanged: inputs are expected to already be
    ``image_size`` x ``image_size``.

    Normalization uses a single-channel mean/std, so inputs are expected to be
    grayscale uint8 images. The pipeline ends with ToTensorV2 (CHW tensor).
    """
    return A.Compose([
        A.RandomResizedCrop(size=(image_size, image_size), scale=(0.85, 1.0), p=0.8),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        # NOTE(review): `var_limit` is the older albumentations GaussNoise API;
        # newer releases use `std_range` instead — confirm against the
        # installed version that this argument is still honoured.
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ])


def get_val_transforms(image_size=224):
    """Validation transforms: deterministic resize, normalize, to CHW tensor."""
    return A.Compose([
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ])


class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset of grayscale chest X-rays with optional lung preprocessing.

    Each item is ``(image, label)`` where ``image`` is a single-channel tensor
    ``(1, H, W)`` and ``label`` is returned unchanged from ``labels``.
    """

    def __init__(self, image_paths, labels, transforms=None, use_preprocessing=True,
                 image_size=224):
        """
        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of labels aligned with ``image_paths``.
            transforms: Optional albumentations pipeline. Assumed to end with
                ToTensorV2 so it yields a torch tensor.
            use_preprocessing: Run ``LungPreprocessor`` when True; otherwise
                just load and resize.
            image_size: Square side length images are resized to before
                ``transforms`` (default 224, the previous hard-coded value).
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.image_size = image_size
        self.preprocessor = LungPreprocessor(image_size) if use_preprocessing else None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally preprocess and transform the ``idx``-th image.

        Raises:
            ValueError: If the image file cannot be read.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        if self.use_preprocessing:
            image = self.preprocessor.preprocess(img_path)
        else:
            image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            # Fail with a clear message instead of a cryptic cv2.resize error.
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            image = cv2.resize(image, (self.image_size, self.image_size))

        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']
        else:
            # No transforms: scale uint8 to [0, 1] and add a channel axis.
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        # Guarantee a single-channel (1, H, W) tensor. Check the rank first so
        # a 2-D tensor that happens to have 3 rows is not averaged away.
        if image.dim() == 2:
            image = image.unsqueeze(0)
        elif image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)

        return image, label