| | import cv2
|
| | import numpy as np
|
| | import torch
|
| | import albumentations as A
|
| | from albumentations.pytorch import ToTensorV2
|
| |
|
class LungPreprocessor:
    """Preprocessing pipeline for chest X-ray images (Phase 4 architecture).

    Pipeline: border crop + Otsu-based lung-field segmentation, percentile
    intensity normalization, then CLAHE contrast enhancement.
    """

    def __init__(self, image_size=224):
        # Side length of the square output image.
        self.image_size = image_size

    def remove_artifacts_and_segment(self, image):
        """Crop borders, build a soft lung-field mask with Otsu + morphology,
        and return the masked image resized to (image_size, image_size).
        """
        # Collapse to a single channel if an RGB image was passed in.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if image.ndim == 3 else image.copy()

        # Trim 3% off every edge to drop burned-in text / frame artifacts.
        rows, cols = gray.shape
        margin = int(min(rows, cols) * 0.03)
        roi = gray[margin:rows - margin, margin:cols - margin]

        # Smooth before thresholding so Otsu is not driven by pixel noise.
        smoothed = cv2.GaussianBlur(roi, (5, 5), 0)

        # Otsu split; lung fields are dark, so invert to make them foreground.
        _, otsu = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        lung_mask = cv2.bitwise_not(otsu)

        # Opening removes small bright specks; closing fills holes in the field.
        open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        lung_mask = cv2.morphologyEx(lung_mask, cv2.MORPH_OPEN, open_kernel, iterations=2)
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 30))
        lung_mask = cv2.morphologyEx(lung_mask, cv2.MORPH_CLOSE, close_kernel, iterations=2)

        # Feather the mask edge so the blend below has no hard seams.
        lung_mask = cv2.GaussianBlur(lung_mask, (21, 21), 0)

        # Soft blend: full intensity inside the mask, 20% intensity outside.
        weights = lung_mask.astype(float) / 255.0
        blended = roi * weights + roi * 0.2 * (1 - weights)
        blended = blended.astype(np.uint8)

        return cv2.resize(blended, (self.image_size, self.image_size))

    def apply_clahe(self, image):
        """Boost local contrast with CLAHE (clip limit 2.0, 8x8 tiles)."""
        return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(image)

    def normalize_intensity(self, image):
        """Clip to the [2, 98] percentile range and rescale to 0-255 uint8."""
        lo, hi = np.percentile(image, (2, 98))
        clipped = np.clip(image, lo, hi)
        # Guard against a zero span (flat image) with a tiny epsilon.
        span = max(1e-8, clipped.max() - clipped.min())
        return ((clipped - clipped.min()) / span * 255).astype(np.uint8)

    def preprocess(self, image_path):
        """Run the full pipeline on the image at *image_path*.

        Raises ValueError if the file cannot be read.
        """
        raw = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if raw is None:
            raise ValueError(f"Failed to load image: {image_path}")

        segmented = self.remove_artifacts_and_segment(raw)
        normalized = self.normalize_intensity(segmented)
        return self.apply_clahe(normalized)
|
| |
|
def get_train_transforms(image_size=224):
    """Training augmentations - Advanced Medical Safe (Phase 5 Arch).

    Conservative, medically-plausible augmentations: mild geometric
    jitter, small photometric shifts, light noise/distortion, then
    single-channel normalization and tensor conversion.
    """
    steps = [
        A.RandomResizedCrop(size=(image_size, image_size), scale=(0.85, 1.0), p=0.8),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        # ImageNet grayscale-style normalization; inputs are 0-255 uint8.
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
| |
|
def get_val_transforms(image_size=224):
    """Validation transforms: deterministic resize + normalize + tensor."""
    steps = [
        A.Resize(height=image_size, width=image_size),
        # Same normalization as training; inputs are 0-255 uint8.
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
|
| |
|
class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset pairing image paths with labels, with optional preprocessing.

    Args:
        image_paths: sequence of paths to grayscale chest X-ray images.
        labels: sequence of labels aligned with image_paths.
        transforms: optional albumentations pipeline applied as
            transforms(image=...); expected to return a tensor under 'image'
            (e.g. via ToTensorV2).
        use_preprocessing: when True, run LungPreprocessor.preprocess on
            each image; otherwise just load grayscale and resize.
        image_size: output side length (was hard-coded to 224; default keeps
            the old behavior, and is also forwarded to LungPreprocessor so
            both loading paths agree).
    """

    def __init__(self, image_paths, labels, transforms=None,
                 use_preprocessing=True, image_size=224):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.image_size = image_size
        self.preprocessor = (
            LungPreprocessor(image_size=image_size) if use_preprocessing else None
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return (image_tensor, label); image is a single-channel tensor.

        Raises ValueError if the image file cannot be read.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        if self.use_preprocessing:
            # preprocess() already validates the read and resizes.
            image = self.preprocessor.preprocess(img_path)
        else:
            image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            # BUGFIX: cv2.imread returns None on failure; previously this
            # crashed later inside cv2.resize with a cryptic error. Match
            # LungPreprocessor.preprocess's explicit failure mode.
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            image = cv2.resize(image, (self.image_size, self.image_size))

        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']
        else:
            # No transform pipeline: scale uint8 HxW to a (1, H, W) float tensor.
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        # Force single channel (assumes `image` is a tensor here — holds when
        # transforms end with ToTensorV2, as in get_*_transforms above).
        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)
        elif image.dim() == 2:
            image = image.unsqueeze(0)

        return image, label
|
| |
|