import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
class LungPreprocessor:
    """Advanced Preprocessing pipeline for chest X-rays (Phase 4 Arch)"""

    def __init__(self, image_size=224):
        # Side length of the square output produced by remove_artifacts_and_segment.
        self.image_size = image_size

    def remove_artifacts_and_segment(self, image):
        """Remove text artifacts, borders, and segment lung field using Otsu & Morphological ops"""
        # Collapse to single channel first; every downstream op expects grayscale.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image.copy()

        # Step 1: trim a thin frame (3% of the short side) where burned-in
        # markers and collimator borders typically sit.
        h, w = gray.shape
        margin = int(min(h, w) * 0.03)
        trimmed = gray[margin:h - margin, margin:w - margin]

        # Step 2: rough lung segmentation — denoise, Otsu threshold, morphology.
        smoothed = cv2.GaussianBlur(trimmed, (5, 5), 0)
        # Otsu separates the dark lung fields from brighter surrounding tissue;
        # lungs end up as 0 in the binary result, so invert to make them 255.
        _, otsu = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        lung_mask = cv2.bitwise_not(otsu)

        # Opening removes small bright specks (text / markers) from the mask.
        open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        lung_mask = cv2.morphologyEx(lung_mask, cv2.MORPH_OPEN, open_kernel, iterations=2)
        # Closing fills interior holes such as the heart shadow.
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 30))
        lung_mask = cv2.morphologyEx(lung_mask, cv2.MORPH_CLOSE, close_kernel, iterations=2)

        # Feather the mask edges so the blend below has no hard seams.
        lung_mask = cv2.GaussianBlur(lung_mask, (21, 21), 0)

        # Blend: full intensity inside the mask, 20% outside, so some anatomical
        # context around the lungs survives.
        weights = lung_mask.astype(float) / 255.0
        blended = trimmed * weights + trimmed * 0.2 * (1 - weights)  # Keep 20% background context
        blended = blended.astype(np.uint8)

        # Standardize spatial size before CLAHE.
        return cv2.resize(blended, (self.image_size, self.image_size))

    def apply_clahe(self, image):
        """Apply CLAHE for contrast enhancement"""
        return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(image)

    def normalize_intensity(self, image):
        """Normalize intensity values out of extreme percentiles"""
        # Clip to the 2nd/98th percentiles, then rescale the remaining range
        # to the full 0..255 span (epsilon guards constant images).
        lo, hi = np.percentile(image, (2, 98))
        clipped = np.clip(image, lo, hi)
        span = max(1e-8, clipped.max() - clipped.min())
        return ((clipped - clipped.min()) / span * 255).astype(np.uint8)

    def preprocess(self, image_path):
        """Full preprocessing pipeline"""
        raw = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if raw is None:
            raise ValueError(f"Failed to load image: {image_path}")
        staged = self.remove_artifacts_and_segment(raw)
        staged = self.normalize_intensity(staged)
        return self.apply_clahe(staged)
def get_train_transforms(image_size=224):
    """Training augmentations - Advanced Medical Safe (Phase 5 Arch)

    Returns an albumentations Compose producing a normalized torch tensor.
    Geometric/intensity perturbations are kept mild (small rotations, <=15%
    brightness/contrast, gentle grid distortion) — presumably to avoid
    destroying diagnostic features in chest X-rays.

    NOTE(review): `RandomResizedCrop(size=(...))` is the albumentations >=1.4
    signature, while `GaussNoise(var_limit=...)` is the pre-2.0 API — verify
    both arguments against the pinned albumentations version.
    """
    return A.Compose([
        # p=0.8: ~20% of images skip the crop and keep their incoming size.
        A.RandomResizedCrop(size=(image_size, image_size), scale=(0.85, 1.0), p=0.8),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        # Single-channel ImageNet-style stats (uint8 input scaled by 255).
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2()
    ])
def get_val_transforms(image_size=224):
    """Validation transforms"""
    # Deterministic eval-time pipeline: resize, single-channel ImageNet-style
    # normalization, then conversion to a torch tensor.
    pipeline = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset yielding (image_tensor, label) pairs with optional lung preprocessing.

    Args:
        image_paths: sequence of paths to X-ray image files.
        labels: sequence of labels aligned with ``image_paths``.
        transforms: optional albumentations pipeline applied to the HxW uint8 image.
        use_preprocessing: when True, run the full LungPreprocessor pipeline;
            otherwise just load grayscale and resize.
        image_size: target side length for loading/preprocessing. Defaults to
            224, matching the previously hard-coded resize.
    """

    def __init__(self, image_paths, labels, transforms=None, use_preprocessing=True,
                 image_size=224):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.image_size = image_size
        self.preprocessor = LungPreprocessor(image_size) if use_preprocessing else None

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        if self.use_preprocessing:
            # Full artifact-removal / segmentation / CLAHE pipeline.
            image = self.preprocessor.preprocess(img_path)
        else:
            image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None instead of raising on unreadable files;
            # fail loudly here (matches LungPreprocessor.preprocess) rather
            # than letting cv2.resize crash with an opaque error.
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            image = cv2.resize(image, (self.image_size, self.image_size))
        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']
        else:
            # No transform pipeline: scale to [0, 1] float and add channel dim.
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0
        # Normalize channel layout to 1xHxW regardless of what the transform emitted.
        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)  # RGB -> single channel
        elif image.dim() == 2:
            image = image.unsqueeze(0)
        return image, label
|