# TB-Guard-XAI / preprocessing.py
# Uploaded by Vignesh-19 via huggingface_hub (commit a041069, verified).
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
class LungPreprocessor:
    """Advanced preprocessing pipeline for chest X-rays (Phase 4 Arch).

    ``preprocess`` loads an image as grayscale and then runs:

    1. ``remove_artifacts_and_segment`` - trims a thin border, builds a rough
       lung mask (Otsu threshold + morphology), attenuates pixels outside the
       mask and resizes to ``image_size`` x ``image_size``;
    2. ``normalize_intensity`` - clips to the 2nd-98th percentile and
       stretches the result to the full 0-255 range;
    3. ``apply_clahe`` - local contrast enhancement with CLAHE.

    The tuning constants are constructor arguments; the defaults are the
    values this pipeline has always used, so ``LungPreprocessor()`` behaves
    exactly as before.

    Args:
        image_size: Side length, in pixels, of the square output image.
        border_fraction: Fraction of ``min(height, width)`` trimmed from every
            edge before segmentation. Must be in ``[0, 0.5)``.
        background_weight: Multiplier applied to pixels outside the lung mask
            (0 blanks them, 1 leaves them untouched).
        clahe_clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        clahe_tile_grid: ``tileGridSize`` passed to ``cv2.createCLAHE``.

    Raises:
        ValueError: if ``border_fraction`` is outside ``[0, 0.5)``.
    """

    def __init__(self, image_size=224, border_fraction=0.03,
                 background_weight=0.2, clahe_clip_limit=2.0,
                 clahe_tile_grid=(8, 8)):
        # A fraction >= 0.5 would crop the image down to nothing.
        if not 0 <= border_fraction < 0.5:
            raise ValueError(
                f"border_fraction must be in [0, 0.5), got {border_fraction}"
            )
        self.image_size = image_size
        self.border_fraction = border_fraction
        self.background_weight = background_weight
        self.clahe_clip_limit = clahe_clip_limit
        self.clahe_tile_grid = tuple(clahe_tile_grid)

    def remove_artifacts_and_segment(self, image):
        """Trim edge artifacts, softly mask the lung field, and resize.

        Args:
            image: 2-D grayscale array, a 3-D single-channel array, or a 3-D
                array in RGB channel order (converted with ``COLOR_RGB2GRAY``).
                Otsu thresholding presumably needs 8-bit input - TODO confirm
                before passing other dtypes.

        Returns:
            uint8 array of shape ``(image_size, image_size)``.
        """
        # Reduce the input to a single 2-D channel. cvtColor rejects
        # 1-channel input, so squeeze that case instead of converting it.
        if image.ndim == 3 and image.shape[2] == 1:
            gray = image[:, :, 0].copy()
        elif image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image.copy()

        # 1. Trim a thin border (typical location of edge artifacts).
        h, w = gray.shape
        border = int(min(h, w) * self.border_fraction)
        cropped = gray[border:h - border, border:w - border]

        # 2. Rough lung mask. Segmentation runs before resizing, so the
        # kernel sizes below are in source-image pixels and their relative
        # effect depends on the input resolution.
        # Light blur so the Otsu threshold is not driven by pixel noise.
        blur = cv2.GaussianBlur(cropped, (5, 5), 0)
        # Otsu maps pixels darker than the threshold (the lungs, plus any
        # other dark region) to 0 and everything else to 255 ...
        _, thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        # ... so invert it to make the dark lung field the white foreground.
        mask = cv2.bitwise_not(thresh)
        # Opening removes small white specks (e.g. text/markers).
        open_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, open_kernel, iterations=2)
        # Closing fills holes inside the lung region.
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (30, 30))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=2)
        # Feather the mask edge so the blend below has no hard seam.
        mask = cv2.GaussianBlur(mask, (21, 21), 0)

        # 3. Soft masking: masked pixels keep full intensity, the rest are
        # scaled by background_weight so some surrounding context remains.
        mask_float = mask.astype(float) / 255.0
        segmented = (cropped * mask_float
                     + cropped * self.background_weight * (1 - mask_float))
        segmented = segmented.astype(np.uint8)

        # Standardize the size before the intensity steps.
        return cv2.resize(segmented, (self.image_size, self.image_size))

    def apply_clahe(self, image):
        """Apply CLAHE (contrast-limited adaptive histogram equalization).

        Args:
            image: single-channel image (OpenCV's CLAHE accepts 8- or 16-bit).

        Returns:
            The contrast-enhanced image, same shape and dtype as ``image``.
        """
        # Built per call rather than cached on ``self``: keeping OpenCV
        # objects off the instance keeps it picklable (OpenCV objects
        # generally are not), e.g. for DataLoader worker processes.
        clahe = cv2.createCLAHE(clipLimit=self.clahe_clip_limit,
                                tileGridSize=self.clahe_tile_grid)
        return clahe.apply(image)

    def normalize_intensity(self, image):
        """Clip to the 2nd-98th percentile and stretch to the 0-255 range.

        Clipping the extreme tails stops a few very bright or very dark
        pixels from compressing the usable intensity range.

        Returns:
            uint8 array with the same shape as ``image``. A constant image
            maps to all zeros.
        """
        low, high = np.percentile(image, (2, 98))
        clipped = np.clip(image, low, high)
        # Guard against division by zero on constant images.
        span = max(1e-8, clipped.max() - clipped.min())
        scaled = (clipped - clipped.min()) / span * 255
        return scaled.astype(np.uint8)

    def preprocess(self, image_path):
        """Run the full pipeline on one image file.

        Args:
            image_path: path (str or path-like) readable by ``cv2.imread``.

        Returns:
            uint8 array of shape ``(image_size, image_size)``.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        image = self.remove_artifacts_and_segment(image)
        image = self.normalize_intensity(image)
        return self.apply_clahe(image)
def get_train_transforms(image_size=224):
    """Build the training augmentation pipeline (Phase 5 Arch).

    Applies mild augmentations: small random crops, horizontal flips,
    +/-10 degree rotations, brightness/contrast jitter, Gaussian noise and a
    slight grid distortion. It then applies the same single-channel
    normalization as the validation pipeline and converts to a tensor.

    Args:
        image_size: Side length of the square output image.

    Returns:
        An ``A.Compose`` to be called as ``pipeline(image=array)["image"]``.
        Every output is ``image_size`` x ``image_size``, whatever the input
        size.
    """
    return A.Compose([
        A.RandomResizedCrop(size=(image_size, image_size), scale=(0.85, 1.0), p=0.8),
        # The crop above is skipped for 20% of samples, which would then keep
        # their input size. When that differs from image_size, a batch mixes
        # shapes and cannot be collated. Resizing unconditionally fixes this,
        # and should leave already-sized images unchanged.
        A.Resize(height=image_size, width=image_size),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        # NOTE(review): newer albumentations releases replaced `var_limit`
        # with `std_range`; confirm the installed version still honours it.
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ])
def get_val_transforms(image_size=224):
    """Build the deterministic evaluation pipeline.

    The pipeline resizes to ``image_size`` x ``image_size``, applies the
    single-channel normalization also used by the training pipeline
    (mean 0.485, std 0.229 on a 0-255 scale), and converts to a tensor.
    It applies no random augmentation.

    Args:
        image_size: Side length of the square output image.

    Returns:
        An ``A.Compose`` to be called as ``pipeline(image=array)["image"]``.
    """
    steps = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=0.485, std=0.229, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
class PreprocessedDataset(torch.utils.data.Dataset):
    """Chest X-ray dataset yielding ``(image_tensor, label)`` pairs.

    Args:
        image_paths: Indexable sequence of image paths (anything that
            ``str()`` turns into a path ``cv2.imread`` can read).
        labels: Indexable sequence of labels, parallel to ``image_paths``.
        transforms: Optional callable used as ``transforms(image=array)`` that
            returns a mapping with an ``"image"`` entry (e.g. an albumentations
            ``Compose``). If falsy, the uint8 image becomes a float tensor
            scaled to [0, 1] with a leading channel axis.
        use_preprocessing: If True, load through ``LungPreprocessor``.
            Otherwise read as grayscale and only resize.
        image_size: Side length of the loaded image on both loading paths.
            Defaults to 224, the previously hard-coded value.
    """

    def __init__(self, image_paths, labels, transforms=None,
                 use_preprocessing=True, image_size=224):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.image_size = image_size
        self.preprocessor = (
            LungPreprocessor(image_size=image_size) if use_preprocessing else None
        )

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, img_path):
        """Read one image as a 2-D uint8 array of side ``image_size``.

        Raises:
            ValueError: if OpenCV cannot read the file (both paths).
        """
        if self.use_preprocessing:
            return self.preprocessor.preprocess(img_path)
        image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        # Fail with a clear message instead of a cryptic cv2.resize error.
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")
        return cv2.resize(image, (self.image_size, self.image_size))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is returned channel-first. The final reshaping assumes
        ``transforms`` yields a torch tensor (as ``ToTensorV2`` does) - TODO
        confirm for pipelines that return numpy arrays.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        image = self._load_image(img_path)
        if self.transforms:
            image = self.transforms(image=image)['image']
        else:
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        # Collapse 3-channel output to a single channel, and give 2-D output
        # a leading channel axis.
        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)
        elif image.dim() == 2:
            image = image.unsqueeze(0)
        return image, label