"""Chest X-ray preprocessing: lung segmentation, text/marker removal,
albumentations pipelines and a caching ``torch`` Dataset."""
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed


class LungSegmenter:
    """Robust lung field segmentation and text/marker removal using CV techniques."""

    def __init__(self, image_size=224):
        # Stored for API symmetry with LungPreprocessor; the routines below do not read it.
        self.image_size = image_size

    @staticmethod
    def _full_mask(image):
        """Return an all-255 uint8 mask shaped like *image* (the "keep everything" fallback)."""
        return np.ones_like(image, dtype=np.uint8) * 255

    def segment_lungs(self, image):
        """Segment lung fields from a grayscale chest X-ray.

        Args:
            image: 2-D uint8 grayscale image.

        Returns:
            uint8 mask of the same shape: 255 inside the detected lung fields,
            0 elsewhere. When no plausible region is found, or the final mask
            covers less than 5% or more than 75% of the image, an all-255 mask
            is returned so that downstream masking leaves the image unchanged.
        """
        h, w = image.shape
        if h == 0 or w == 0:
            return self._full_mask(image)

        # Contrast-enhance and smooth before any thresholding.
        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(image)
        blur = cv2.GaussianBlur(enhanced, (7, 7), 0)

        # Body mask: Otsu splits body from background; a large close/open
        # fills holes and removes small specks.
        _, body_thresh = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        body_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
        body_mask = cv2.morphologyEx(body_thresh, cv2.MORPH_CLOSE, body_kernel, iterations=3)
        body_mask = cv2.morphologyEx(body_mask, cv2.MORPH_OPEN, body_kernel, iterations=1)

        # Threshold the *inverted* image inside the body, i.e. keep the darker
        # regions of the original as lung candidates.
        inverted = cv2.bitwise_not(blur)
        masked_inverted = cv2.bitwise_and(inverted, inverted, mask=body_mask)
        masked_inverted_blur = cv2.GaussianBlur(masked_inverted, (7, 7), 0)
        _, lung_thresh = cv2.threshold(
            masked_inverted_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        lung_mask = cv2.morphologyEx(lung_thresh, cv2.MORPH_OPEN, kernel, iterations=2)

        # Keep only contours with plausible size, aspect ratio and vertical position.
        contours, _ = cv2.findContours(lung_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        min_area = h * w * 0.02
        max_area = h * w * 0.45
        valid_contours = []
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < min_area or area > max_area:
                continue
            x, y, cw, ch = cv2.boundingRect(cnt)
            if cw < 1 or ch < 1:
                continue
            aspect = cw / ch
            if aspect < 0.3 or aspect > 2.5:
                continue
            if y > h * 0.85:  # starts in the bottom 15% of the image
                continue
            valid_contours.append(cnt)

        if not valid_contours:
            return self._full_mask(image)

        mask = np.zeros_like(image, dtype=np.uint8)
        cv2.drawContours(mask, valid_contours, -1, 255, thickness=cv2.FILLED)

        # Smooth the filled contours: two closing passes, then blur + re-threshold.
        close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=3)
        large_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (31, 31))
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, large_close, iterations=1)
        mask = cv2.GaussianBlur(mask, (15, 15), 0)
        mask = (mask > 50).astype(np.uint8) * 255

        # Reject masks covering an implausible fraction of the image.
        mask_area_frac = mask.mean() / 255.0
        if mask_area_frac < 0.05 or mask_area_frac > 0.75:
            return self._full_mask(image)
        return mask

    def remove_text_markers(self, image):
        """Inpaint text annotations and hardware markers in a grayscale X-ray.

        Small dark-on-light blobs (per adaptive threshold) are inpainted with
        Telea's method when any of these holds:
          * they are very elongated (aspect > 5 or < 0.2);
          * they touch the outer 3% border;
          * they are small and roughly square.

        Args:
            image: 2-D uint8 grayscale image.

        Returns:
            The inpainted image, or *image* itself when nothing qualifies.
        """
        h, w = image.shape
        if h == 0 or w == 0:
            return image

        blur = cv2.GaussianBlur(image, (3, 3), 0)
        thresh = cv2.adaptiveThreshold(
            blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
        )
        thresh = cv2.bitwise_not(thresh)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2))
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=1)

        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        mask = np.zeros_like(image, dtype=np.uint8)
        for cnt in contours:
            area = cv2.contourArea(cnt)
            if area < 5 or area > h * w * 0.01:
                continue
            x, y, cw, ch = cv2.boundingRect(cnt)
            if cw < 1 or ch < 1:
                continue
            aspect = cw / max(ch, 1)
            is_thin_text = aspect > 5 or aspect < 0.2
            is_near_edge = (
                x < w * 0.03 or x + cw > w * 0.97 or y < h * 0.03 or y + ch > h * 0.97
            )
            is_circular_marker = 0.8 < aspect < 1.2 and area < h * w * 0.002
            if is_thin_text or is_near_edge or is_circular_marker:
                cv2.drawContours(mask, [cnt], -1, 255, thickness=cv2.FILLED)

        # Equivalent to the former `mask.sum() < 100` (each set pixel adds 255),
        # but cannot overflow a narrow accumulator on very large images.
        if cv2.countNonZero(mask) == 0:
            return image
        return cv2.inpaint(image, mask, 3, cv2.INPAINT_TELEA)


class LungPreprocessor:
    """Full X-ray preprocessing pipeline built on :class:`LungSegmenter`.

    Pipeline: marker removal -> (optional) lung masking -> percentile
    intensity normalisation -> CLAHE. Output is not resized.
    """

    def __init__(self, image_size=224):
        self.image_size = image_size
        self.segmenter = LungSegmenter(image_size)

    def apply_clahe(self, image):
        """Apply CLAHE (clipLimit=2.0, 8x8 tiles) to a uint8 grayscale image."""
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(image)

    def normalize_intensity(self, image):
        """Clip to the 2nd-98th percentile range and stretch to uint8 0-255."""
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        # max(1e-8, ...) avoids division by zero on (near-)constant images.
        image = ((image - image.min()) / (max(1e-8, image.max() - image.min())) * 255).astype(np.uint8)
        return image

    @staticmethod
    def _load_grayscale(image_path):
        """Read *image_path* as grayscale; raise ValueError if OpenCV cannot read it."""
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        return image

    def _run_pipeline(self, image, segment_lung):
        """Shared pipeline for every public entry point.

        Returns:
            (processed_image, lung_mask); lung_mask is None when
            *segment_lung* is False, else the mask at input resolution.

        Raises:
            ValueError: mean intensity outside [1, 253] (likely blank or
                saturated), or the result is near-uniform.
        """
        mean_intensity = image.mean()
        if mean_intensity < 1 or mean_intensity > 253:
            raise ValueError(f"Image appears corrupted (mean intensity: {mean_intensity:.1f})")

        image = self.segmenter.remove_text_markers(image)

        lung_mask = None
        if segment_lung:
            lung_mask = self.segmenter.segment_lungs(image)
            mask_float = lung_mask.astype(float) / 255.0
            # Keep lung pixels at full intensity; attenuate the rest to 10%.
            image = (image * mask_float + image * 0.1 * (1 - mask_float)).astype(np.uint8)

        image = self.normalize_intensity(image)
        image = self.apply_clahe(image)

        pixel_range = float(image.max()) - float(image.min())
        if pixel_range < 2.0:
            raise ValueError("Preprocessing produced near-uniform image")
        return image, lung_mask

    def preprocess(self, image_path, segment_lung=True):
        """Load *image_path* as grayscale and return the preprocessed uint8 image."""
        image, _ = self._run_pipeline(self._load_grayscale(image_path), segment_lung)
        return image

    def preprocess_array(self, image, segment_lung=True):
        """Preprocess an in-memory grayscale image array with optional lung segmentation.

        Raises:
            ValueError: *image* is None, empty or not 2-D, or see _run_pipeline.
        """
        if image is None:
            raise ValueError("Input image is None")
        if image.size == 0:
            raise ValueError("Input image is empty")
        if image.ndim != 2:
            raise ValueError(f"Expected a 2-D grayscale image, got shape {image.shape}")
        processed, _ = self._run_pipeline(image, segment_lung)
        return processed

    def preprocess_with_mask(self, image_path):
        """Preprocess image and return (processed_image, lung_mask_at_original_resolution)."""
        return self._run_pipeline(self._load_grayscale(image_path), segment_lung=True)


def get_train_transforms(image_size=224):
    """Training augmentations for single-channel uint8 images.

    Ends with Normalize(mean=0.5, std=0.5) and ToTensorV2.
    """
    return A.Compose([
        A.Resize(height=image_size, width=image_size),
        A.RandomResizedCrop(size=(image_size, image_size), scale=(0.85, 1.0), p=0.5),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ])


def get_val_transforms(image_size=224):
    """Deterministic validation pipeline: resize, normalise, convert to tensor."""
    return A.Compose([
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ])


class PreprocessedDataset(torch.utils.data.Dataset):
    """Dataset of grayscale X-rays with optional lung preprocessing and caching.

    Args:
        image_paths: sequence of image paths (``str`` or ``Path``).
        labels: sequence of labels aligned with *image_paths*.
        transforms: albumentations pipeline applied to each image, or None.
        use_preprocessing: run :class:`LungPreprocessor`; otherwise images
            are only read and resized.
        cache_pt: store preprocessed images next to the source as
            ``<stem>.pt.cache`` and reuse them.
        load_to_ram: load every image into memory up front.
    """

    def __init__(self, image_paths, labels, transforms=None, use_preprocessing=True,
                 cache_pt=True, load_to_ram=False):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.cache_pt = cache_pt
        self.preprocessor = LungPreprocessor() if use_preprocessing else None
        self.ram_cache = None
        if load_to_ram:
            self._preload_to_ram()

    @staticmethod
    def _cache_path(img_path):
        """Cache file for *img_path*: same location and stem, '.pt.cache' suffix."""
        return Path(img_path).with_suffix('.pt.cache')

    @staticmethod
    def _write_cache(cache_path, image):
        """Save *image* via a temp file + rename so an interrupted write never
        leaves a truncated cache file behind."""
        tmp_path = cache_path.with_name(cache_path.name + '.tmp')
        torch.save(torch.from_numpy(image), tmp_path)
        tmp_path.replace(cache_path)

    def _target_size(self):
        """Square size implied by the first transform exposing ``height``;
        otherwise the preprocessor's image_size, or 224."""
        if self.transforms:
            for t in self.transforms:
                if hasattr(t, 'height'):
                    return t.height
        return self.preprocessor.image_size if self.preprocessor else 224

    def _get_preprocessed(self, img_path, target_size=224):
        """Return one uint8 image as a NumPy array.

        Without preprocessing: read as grayscale and resize to
        *target_size* x *target_size*. With preprocessing: return the
        full-resolution LungPreprocessor output, reading/writing the disk
        cache when ``cache_pt`` is set.
        """
        if not self.use_preprocessing:
            image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise ValueError(f"Failed to load image: {img_path}")
            return cv2.resize(image, (target_size, target_size))

        cache_path = self._cache_path(img_path)
        if self.cache_pt and cache_path.exists():
            return torch.load(cache_path, weights_only=True).numpy()

        image = self.preprocessor.preprocess(str(img_path), segment_lung=True)
        if self.cache_pt:
            self._write_cache(cache_path, image)
        return image

    def _build_cache(self, paths):
        """Best-effort: preprocess *paths* and write their cache files.

        Returns:
            list of ``(path, exception)`` pairs for images that failed.
        """
        processor = LungPreprocessor()
        failures = []
        for p in paths:
            try:
                img = processor.preprocess(str(p), segment_lung=True)
                self._write_cache(self._cache_path(p), img)
            except Exception as exc:  # best-effort: record and keep going
                failures.append((p, exc))
        return failures

    def _load_for_ram(self, path, target_size):
        """Load one image for the RAM cache; on failure warn and return a zero image."""
        try:
            return self._get_preprocessed(path, target_size)
        except Exception as e:
            print(f"  Warning: Failed to load {path}: {e}")
            return np.zeros((target_size, target_size), dtype=np.uint8)

    def _preload_to_ram(self):
        """Load every image (honouring use_preprocessing/cache_pt) into self.ram_cache.

        Images are loaded on a small thread pool. Most per-image work is
        OpenCV/NumPy native code and file I/O, so threads overlap usefully.
        The order of self.ram_cache matches self.image_paths.
        """
        n_total = len(self.image_paths)
        print(f"  Pre-loading {n_total} images into RAM...")
        target_size = self._target_size()
        n_workers = max(1, min(8, n_total))

        if self.use_preprocessing:
            n_todo = sum(
                1 for p in self.image_paths
                if not (self.cache_pt and self._cache_path(p).exists())
            )
            if n_todo:
                print(f"  Processing {n_todo} uncached images with {n_workers} workers...")

        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            results = pool.map(lambda p: self._load_for_ram(p, target_size), self.image_paths)
            self.ram_cache = list(tqdm(results, total=n_total, desc="Loading RAM"))

        n = len(self.ram_cache)
        # Preprocessed images are kept at full resolution, so sum real sizes.
        size_gb = sum(img.nbytes for img in self.ram_cache) / 1e9
        print(f"  Cached {n} images in RAM (~{size_gb:.1f} GB)")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)``; the image has a leading channel dim of 1."""
        label = self.labels[idx]
        if self.ram_cache is not None:
            image = self.ram_cache[idx]
        elif self.use_preprocessing:
            image = self._get_preprocessed(self.image_paths[idx])
        else:
            image = self._get_preprocessed(self.image_paths[idx], self._target_size())

        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']
        else:
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        # Normalise to a single-channel (1, H, W) tensor.
        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)
        elif image.dim() == 2:
            image = image.unsqueeze(0)
        return image, label