| import cv2 |
| import numpy as np |
| import torch |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
| from tqdm import tqdm |
| from pathlib import Path |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
class LungSegmenter:
    """Classical-CV lung-field segmentation and text/marker removal for chest X-rays.

    Both public methods take a 2-D image and return an array of the same shape.
    The OpenCV calls used (CLAHE, adaptive threshold, inpaint) operate on
    single-channel 8-bit images.
    """

    def __init__(self, image_size=224):
        # Stored for API symmetry with LungPreprocessor; the methods below do not read it.
        self.image_size = image_size

    @staticmethod
    def _full_mask(image):
        """Return an all-foreground (255) uint8 mask shaped like ``image``."""
        return np.full_like(image, 255, dtype=np.uint8)

    @staticmethod
    def _body_mask(smoothed):
        """Otsu-threshold the bright body region, then close/open it into a solid silhouette."""
        _, body = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
        body = cv2.morphologyEx(body, cv2.MORPH_CLOSE, disk, iterations=3)
        return cv2.morphologyEx(body, cv2.MORPH_OPEN, disk, iterations=1)

    @staticmethod
    def _is_lung_contour(contour, rows, image_area):
        """Accept contours of plausible lung size (2%-45% of the image) and shape.

        The bounding-box aspect ratio must lie in [0.3, 2.5], and the box top
        must not start in the bottom 15% of the image.
        """
        area = cv2.contourArea(contour)
        if not (image_area * 0.02 <= area <= image_area * 0.45):
            return False
        _, top, width, height = cv2.boundingRect(contour)
        if width < 1 or height < 1:
            return False
        return 0.3 <= width / height <= 2.5 and top <= rows * 0.85

    @staticmethod
    def _is_marker_contour(contour, rows, cols, image_area):
        """Flag small contours (area in [5, 1% of image]) that look like annotations.

        A contour is flagged if it is a thin stroke (aspect > 5 or < 0.2),
        touches the outer 3% border band, or is a small (< 0.2% of image)
        blob with a roughly square bounding box.
        """
        area = cv2.contourArea(contour)
        if area < 5 or area > image_area * 0.01:
            return False
        left, top, width, height = cv2.boundingRect(contour)
        if width < 1 or height < 1:
            return False
        aspect = width / height
        thin_stroke = aspect > 5 or aspect < 0.2
        touches_border = (
            left < cols * 0.03
            or left + width > cols * 0.97
            or top < rows * 0.03
            or top + height > rows * 0.97
        )
        small_squarish = 0.8 < aspect < 1.2 and area < image_area * 0.002
        return thin_stroke or touches_border or small_squarish

    def segment_lungs(self, image):
        """Segment lung fields from a chest X-ray.

        Returns a uint8 mask (0 or 255) shaped like ``image``. Falls back to an
        all-255 mask when the image is empty, when no contour passes the
        lung-shape filters, or when the final mask covers less than 5% or more
        than 75% of the image.
        """
        rows, cols = image.shape
        if rows == 0 or cols == 0:
            return self._full_mask(image)

        smoothed = cv2.GaussianBlur(
            cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(image), (7, 7), 0
        )
        body = self._body_mask(smoothed)

        # Lungs are dark inside the body: invert, keep only the body region, re-Otsu.
        dark = cv2.bitwise_not(smoothed)
        dark_in_body = cv2.GaussianBlur(cv2.bitwise_and(dark, dark, mask=body), (7, 7), 0)
        _, candidates = cv2.threshold(
            dark_in_body, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        candidates = cv2.morphologyEx(
            candidates,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)),
            iterations=2,
        )
        found, _ = cv2.findContours(candidates, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        image_area = rows * cols
        keep = [c for c in found if self._is_lung_contour(c, rows, image_area)]
        if not keep:
            return self._full_mask(image)

        lungs = np.zeros_like(image, dtype=np.uint8)
        cv2.drawContours(lungs, keep, -1, 255, thickness=cv2.FILLED)
        # Two closing passes (medium, then large) bridge gaps and fill holes.
        for size, passes in ((15, 3), (31, 1)):
            disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
            lungs = cv2.morphologyEx(lungs, cv2.MORPH_CLOSE, disk, iterations=passes)
        # Blur + re-threshold smooths the mask outline.
        lungs = np.where(cv2.GaussianBlur(lungs, (15, 15), 0) > 50, 255, 0).astype(np.uint8)

        coverage = lungs.mean() / 255.0
        if not 0.05 <= coverage <= 0.75:
            return self._full_mask(image)
        return lungs

    def remove_text_markers(self, image):
        """Inpaint small annotation-like features (text strokes, border marks, markers).

        Candidates are pixels no brighter than their local 11x11
        Gaussian-weighted mean minus 2. Flagged contours are filled into a mask
        and removed with Telea inpainting (radius 3). Returns ``image`` itself
        when it is empty or nothing is flagged.
        """
        rows, cols = image.shape
        if rows == 0 or cols == 0:
            return image

        binary = cv2.adaptiveThreshold(
            cv2.GaussianBlur(image, (3, 3), 0),
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            11,
            2,
        )
        binary = cv2.morphologyEx(
            cv2.bitwise_not(binary),
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2)),
            iterations=1,
        )
        found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        image_area = rows * cols
        flagged = np.zeros_like(image, dtype=np.uint8)
        for contour in found:
            if self._is_marker_contour(contour, rows, cols, image_area):
                cv2.drawContours(flagged, [contour], -1, 255, thickness=cv2.FILLED)

        # Mask pixels are 0 or 255, so this holds only when nothing was flagged.
        if flagged.sum() < 100:
            return image
        return cv2.inpaint(image, flagged, 3, cv2.INPAINT_TELEA)
|
|
|
|
class LungPreprocessor:
    """Chest X-ray preprocessing pipeline built on ``LungSegmenter``.

    Pipeline shared by all public entry points:
      1. validate the input (non-empty, mean intensity in [1, 253], 2-D uint8);
      2. inpaint text annotations / hardware markers;
      3. optionally segment the lung fields and attenuate everything outside
         them to 10% of its intensity;
      4. clip to the 2nd-98th percentile and stretch to the full uint8 range;
      5. apply CLAHE.
    The output keeps the input's resolution; nothing here resizes.
    """

    def __init__(self, image_size=224):
        # Stored and forwarded to the segmenter; the pipeline itself never resizes.
        self.image_size = image_size
        self.segmenter = LungSegmenter(image_size)

    def apply_clahe(self, image):
        """Return ``image`` enhanced with CLAHE (clip limit 2.0, 8x8 tiles)."""
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(image)

    def normalize_intensity(self, image):
        """Clip to the 2nd-98th percentile and stretch linearly to uint8 [0, 255].

        A constant image maps to all zeros (the denominator is floored at 1e-8).
        """
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        span = max(1e-8, image.max() - image.min())
        return ((image - image.min()) / span * 255).astype(np.uint8)

    @staticmethod
    def _load_grayscale(image_path):
        """Read ``image_path`` as grayscale; raise ValueError if OpenCV cannot load it."""
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        return image

    @staticmethod
    def _validate(image):
        """Raise ValueError unless ``image`` is a usable, non-degenerate 2-D uint8 array."""
        if image.size == 0:
            raise ValueError("Input image is empty")
        mean = float(image.mean())
        # Nearly all-black / all-white inputs are treated as corrupt reads.
        if mean < 1 or mean > 253:
            raise ValueError(f"Image appears corrupted (mean intensity: {mean:.1f})")
        # The OpenCV steps (adaptive threshold, CLAHE) need single-channel 8-bit input;
        # fail here with a clear message instead of deep inside OpenCV.
        if image.ndim != 2 or image.dtype != np.uint8:
            raise ValueError(
                "Expected a 2-D uint8 grayscale image, "
                f"got shape {image.shape} and dtype {image.dtype}"
            )

    @staticmethod
    def _attenuate_outside(image, mask):
        """Keep pixels inside ``mask`` (0/255) and scale the rest to 10% intensity."""
        weight = mask.astype(float) / 255.0
        return (image * weight + image * 0.1 * (1 - weight)).astype(np.uint8)

    def _run_pipeline(self, image, segment_lung):
        """Apply steps 2-5 to a validated image; return ``(processed, lung_mask_or_None)``."""
        image = self.segmenter.remove_text_markers(image)
        lung_mask = None
        if segment_lung:
            lung_mask = self.segmenter.segment_lungs(image)
            image = self._attenuate_outside(image, lung_mask)
        image = self.normalize_intensity(image)
        image = self.apply_clahe(image)
        if float(image.max()) - float(image.min()) < 2.0:
            raise ValueError("Preprocessing produced near-uniform image")
        return image, lung_mask

    def preprocess(self, image_path, segment_lung=True):
        """Load ``image_path`` as grayscale and run the full pipeline.

        Returns the processed uint8 image at the source resolution.
        Raises ValueError if the file cannot be loaded, looks corrupted, or
        preprocessing yields a near-uniform image.
        """
        image = self._load_grayscale(image_path)
        self._validate(image)
        return self._run_pipeline(image, segment_lung)[0]

    def preprocess_array(self, image, segment_lung=True):
        """Run the full pipeline on an in-memory 2-D uint8 grayscale array.

        Raises ValueError for ``None``, empty, non-2-D or non-uint8 input, for
        corrupted-looking input, or when the result is near-uniform.
        """
        if image is None:
            raise ValueError("Input image is None")
        self._validate(image)
        return self._run_pipeline(image, segment_lung)[0]

    def preprocess_with_mask(self, image_path):
        """Preprocess image and return (processed_image, lung_mask_at_original_resolution).

        Lung segmentation is always applied; the mask is the uint8 (0/255)
        output of ``LungSegmenter.segment_lungs`` for the marker-cleaned image.
        """
        image = self._load_grayscale(image_path)
        self._validate(image)
        return self._run_pipeline(image, segment_lung=True)
|
|
|
|
def get_train_transforms(image_size=224):
    """Build the training augmentation pipeline for single-channel X-rays.

    The pipeline resizes to ``image_size`` x ``image_size``, applies mild
    spatial and photometric augmentation, and normalizes with mean/std 0.5
    (on a 0-255 scale, i.e. to [-1, 1]). It ends by converting to a tensor.
    """
    square = (image_size, image_size)
    augment = [
        A.RandomResizedCrop(size=square, scale=(0.85, 1.0), p=0.5),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        # NOTE(review): `var_limit` is the older GaussNoise argument; newer
        # albumentations releases use `std_range` — confirm against the pinned version.
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
    ]
    finalize = [
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose([A.Resize(height=image_size, width=image_size), *augment, *finalize])
|
|
|
|
def get_val_transforms(image_size=224):
    """Build the deterministic evaluation pipeline.

    The pipeline resizes to ``image_size`` x ``image_size`` and normalizes
    with mean/std 0.5 (on a 0-255 scale, i.e. to [-1, 1]). It then converts
    the result to a tensor.
    """
    stages = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(stages)
|
|
|
|
class PreprocessedDataset(torch.utils.data.Dataset):
    """Chest X-ray dataset yielding ``(image_tensor, label)`` pairs.

    Images are read as grayscale. With ``use_preprocessing`` they go through
    ``LungPreprocessor``, which keeps the source resolution. Otherwise they
    are plainly resized to the size inferred from ``transforms``.

    Args:
        image_paths: Sequence of image paths (``pathlib.Path`` or ``str``).
        labels: Sequence of labels, indexed in parallel with ``image_paths``.
        transforms: Optional albumentations-style pipeline, called as
            ``transforms(image=array)`` and expected to return a dict with an
            ``'image'`` tensor. Its first transform exposing ``height`` sets
            the plain-resize size.
        use_preprocessing: Apply the lung preprocessing pipeline.
        cache_pt: Reuse / write preprocessed images as ``.pt.cache`` files
            next to the source images.
        load_to_ram: Load every image into memory during construction.
    """

    def __init__(self, image_paths, labels, transforms=None, use_preprocessing=True,
                 cache_pt=True, load_to_ram=False):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.cache_pt = cache_pt
        self.preprocessor = LungPreprocessor() if use_preprocessing else None
        self.ram_cache = None  # list of uint8 arrays once _preload_to_ram has run

        if load_to_ram:
            self._preload_to_ram()

    @staticmethod
    def _cache_path(img_path):
        """Return the cache file for ``img_path`` (its suffix replaced by ``.pt.cache``)."""
        # NOTE(review): images sharing a stem but differing only in extension
        # (e.g. a.png / a.jpg) map to the same cache file.
        return Path(img_path).with_suffix('.pt.cache')

    def _infer_target_size(self):
        """Return ``height`` of the first transform that has one, else a default.

        Falls back to ``self.preprocessor.image_size`` (224 without a
        preprocessor) when ``transforms`` is missing, empty or not iterable.
        """
        if self.transforms:
            try:
                stages = iter(self.transforms)
            except TypeError:  # plain callable rather than a Compose-like sequence
                stages = iter(())
            for stage in stages:
                height = getattr(stage, 'height', None)
                if height is not None:
                    return height
        return self.preprocessor.image_size if self.preprocessor else 224

    def _get_preprocessed(self, img_path, target_size=224):
        """Return the uint8 image for ``img_path``.

        Without preprocessing, the image is read as grayscale and resized to
        ``target_size`` x ``target_size``. With preprocessing, a cache file is
        reused if present (and ``cache_pt``). Otherwise the lung pipeline runs
        and its result is cached when ``cache_pt`` is set.
        Raises ValueError if the raw image cannot be read.
        """
        if not self.use_preprocessing:
            image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            return cv2.resize(image, (target_size, target_size))

        cache_path = self._cache_path(img_path)
        if self.cache_pt and cache_path.exists():
            return torch.load(cache_path, weights_only=True).numpy()

        image = self.preprocessor.preprocess(str(img_path), segment_lung=True)
        if self.cache_pt:
            torch.save(torch.from_numpy(image), cache_path)
        return image

    def _build_cache(self, paths):
        """Preprocess ``paths`` and write their cache files (thread-pool worker).

        Uses its own ``LungPreprocessor``. Failures are best-effort: one bad
        image must not abort the batch. They are returned as
        ``{path: exception}`` so the caller can report them.
        """
        processor = LungPreprocessor()
        failures = {}
        for p in paths:
            try:
                img = processor.preprocess(str(p), segment_lung=True)
                torch.save(torch.from_numpy(img), self._cache_path(p))
            except Exception as exc:
                failures[p] = exc
        return failures

    def _warm_disk_cache(self):
        """Build missing cache files in parallel; return ``{path: exception}`` for failures."""
        uncached = [p for p in self.image_paths if not self._cache_path(p).exists()]
        if not uncached:
            return {}
        n_workers = min(8, len(uncached))
        chunk_size = (len(uncached) + n_workers - 1) // n_workers
        chunks = [uncached[i:i + chunk_size] for i in range(0, len(uncached), chunk_size)]
        print(f" Processing {len(uncached)} uncached images with {n_workers} workers...")
        failures = {}
        # OpenCV releases the GIL in its heavy kernels, so threads give real parallelism here.
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            futures = [pool.submit(self._build_cache, chunk) for chunk in chunks]
            for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
                failures.update(future.result())
        return failures

    def _preload_to_ram(self):
        """Load every image into ``self.ram_cache`` (a list of uint8 arrays).

        With preprocessing and ``cache_pt`` enabled, missing cache files are
        first built in parallel. Images that fail are replaced by a black
        ``target_size`` x ``target_size`` placeholder and a warning is printed.
        The placeholder keeps indices aligned with ``labels``.
        """
        n_total = len(self.image_paths)
        print(f" Pre-loading {n_total} images into RAM...")

        failures = {}
        if self.use_preprocessing and self.cache_pt:
            failures = self._warm_disk_cache()

        target_size = self._infer_target_size()
        self.ram_cache = []
        for path in tqdm(self.image_paths, desc="Loading RAM"):
            error = failures.get(path)
            image = None
            if error is None:
                try:
                    image = self._get_preprocessed(path, target_size)
                except Exception as exc:
                    error = exc
            if error is not None:
                print(f" Warning: Failed to load {path}: {error}")
                image = np.zeros((target_size, target_size), dtype=np.uint8)
            self.ram_cache.append(image)

        # Report the real footprint: preprocessed images keep their source resolution.
        size_gb = sum(img.nbytes for img in self.ram_cache) / 1e9
        print(f" Cached {len(self.ram_cache)} images in RAM (~{size_gb:.1f} GB)")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where ``image`` is a ``(1, H, W)`` tensor.

        Without ``transforms`` the tensor is float in [0, 1]. With transforms,
        3-channel outputs are averaged to one channel and 2-D outputs get a
        channel axis.
        """
        label = self.labels[idx]

        if self.ram_cache is not None:
            image = self.ram_cache[idx]
        else:
            image = self._get_preprocessed(self.image_paths[idx], self._infer_target_size())

        if self.transforms:
            image = self.transforms(image=image)['image']
        else:
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)
        elif image.dim() == 2:
            image = image.unsqueeze(0)

        return image, label
|
|