TB-Guard / preprocessing.py
Vignesh19's picture
Upload preprocessing.py with huggingface_hub
7b60ec5 verified
Raw
History Blame Contribute Delete
13.2 kB
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
class LungSegmenter:
    """Classical-CV helpers for chest X-rays: lung-field masking and annotation removal.

    Both methods unpack ``image.shape`` into two values, so they expect a 2-D
    (single-channel) array; ``segment_lungs`` additionally feeds it to CLAHE.
    ``image_size`` is stored but not used by either method.
    """

    def __init__(self, image_size=224):
        self.image_size = image_size

    def segment_lungs(self, image):
        """Return a 0/255 uint8 lung mask with the same shape as *image*.

        An all-255 mask ("keep everything") is returned instead when the input
        is empty, when no contour passes the size/shape/position filters, or
        when the final mask covers under 5 % or over 75 % of the image.
        """

        def keep_everything():
            return np.ones_like(image, dtype=np.uint8) * 255

        rows, cols = image.shape
        if rows == 0 or cols == 0:
            return keep_everything()

        # Stage 1: body silhouette -- CLAHE + blur, Otsu, then heavy elliptical smoothing.
        smoothed = cv2.GaussianBlur(
            cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8)).apply(image), (7, 7), 0
        )
        body_se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (21, 21))
        _, body = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        body = cv2.morphologyEx(body, cv2.MORPH_CLOSE, body_se, iterations=3)
        body = cv2.morphologyEx(body, cv2.MORPH_OPEN, body_se, iterations=1)

        # Stage 2: invert so dark regions become bright, restrict to the body, Otsu again.
        dark = cv2.bitwise_not(smoothed)
        dark = cv2.bitwise_and(dark, dark, mask=body)
        dark = cv2.GaussianBlur(dark, (7, 7), 0)
        _, candidates = cv2.threshold(dark, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        candidates = cv2.morphologyEx(
            candidates,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)),
            iterations=2,
        )

        # Stage 3: keep only contours of plausible size, aspect ratio and vertical position.
        contours, _ = cv2.findContours(candidates, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        total = rows * cols
        lo_area, hi_area = total * 0.02, total * 0.45

        def plausible(cnt):
            area = cv2.contourArea(cnt)
            if not lo_area <= area <= hi_area:
                return False
            _, top, box_w, box_h = cv2.boundingRect(cnt)
            if box_w < 1 or box_h < 1:
                return False
            # reject very tall/very wide blobs and blobs starting in the bottom 15 %
            return 0.3 <= box_w / box_h <= 2.5 and top <= rows * 0.85

        kept = [cnt for cnt in contours if plausible(cnt)]
        if not kept:
            return keep_everything()

        lung = np.zeros_like(image, dtype=np.uint8)
        cv2.drawContours(lung, kept, -1, 255, thickness=cv2.FILLED)

        # Stage 4: fill gaps with two closing passes, then soften and re-binarise the edge.
        for ksize, reps in (((15, 15), 3), ((31, 31), 1)):
            se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, ksize)
            lung = cv2.morphologyEx(lung, cv2.MORPH_CLOSE, se, iterations=reps)
        lung = (cv2.GaussianBlur(lung, (15, 15), 0) > 50).astype(np.uint8) * 255

        coverage = lung.mean() / 255.0
        if coverage < 0.05 or coverage > 0.75:
            return keep_everything()
        return lung

    def remove_text_markers(self, image):
        """Inpaint small annotation-like blobs (text strokes, edge labels, round markers).

        Returns *image* itself (not a copy) when it is empty or when no blob
        qualifies; otherwise returns a Telea-inpainted copy.
        """
        rows, cols = image.shape
        if rows == 0 or cols == 0:
            return image
        total = rows * cols

        binary = cv2.adaptiveThreshold(
            cv2.GaussianBlur(image, (3, 3), 0),
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            11,
            2,
        )
        binary = cv2.morphologyEx(
            cv2.bitwise_not(binary),
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2, 2)),
            iterations=1,
        )
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        def looks_like_marker(cnt):
            area = cv2.contourArea(cnt)
            if area < 5 or area > total * 0.01:
                return False
            x, y, box_w, box_h = cv2.boundingRect(cnt)
            if box_w < 1 or box_h < 1:
                return False
            ratio = box_w / max(box_h, 1)
            if ratio > 5 or ratio < 0.2:  # strongly elongated: text-like strokes
                return True
            touches_border = (
                x < cols * 0.03
                or x + box_w > cols * 0.97
                or y < rows * 0.03
                or y + box_h > rows * 0.97
            )
            if touches_border:
                return True
            # small, roughly square blob
            return 0.8 < ratio < 1.2 and area < total * 0.002

        hits = np.zeros_like(image, dtype=np.uint8)
        for cnt in contours:
            if looks_like_marker(cnt):
                cv2.drawContours(hits, [cnt], -1, 255, thickness=cv2.FILLED)

        # hits only ever holds 0 or 255, so "any pixel set" is the trigger
        if not hits.any():
            return image
        return cv2.inpaint(image, hits, 3, cv2.INPAINT_TELEA)
class LungPreprocessor:
    """Clean, optionally lung-mask, and contrast-normalise grayscale chest X-rays.

    The three public entry points (``preprocess``, ``preprocess_array`` and
    ``preprocess_with_mask``) all delegate to one shared pipeline, so they
    cannot drift apart. The pipeline steps are:

    1. reject empty inputs and near-black / near-white inputs (mean < 1 or > 253),
    2. inpaint text annotations and markers,
    3. optionally keep lung fields at full intensity and dim the rest to 10 %,
    4. clip to the 2nd-98th intensity percentiles and rescale to 0-255,
    5. apply CLAHE, and reject near-uniform results.

    ``image_size`` is stored and forwarded to ``LungSegmenter``, but no step
    resizes: outputs keep the input resolution.
    """

    def __init__(self, image_size=224):
        self.image_size = image_size
        self.segmenter = LungSegmenter(image_size)

    def apply_clahe(self, image):
        """Apply CLAHE (clip limit 2.0, 8x8 tiles) to a uint8 grayscale image."""
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return clahe.apply(image)

    def normalize_intensity(self, image):
        """Clip *image* to its 2nd-98th percentile range and rescale linearly to uint8 0-255."""
        p2, p98 = np.percentile(image, (2, 98))
        image = np.clip(image, p2, p98)
        # the 1e-8 floor avoids dividing by zero on flat images (result is all zeros)
        span = max(1e-8, image.max() - image.min())
        return ((image - image.min()) / span * 255).astype(np.uint8)

    @staticmethod
    def _load_grayscale(image_path):
        """Read *image_path* as single-channel; raise ValueError if OpenCV cannot read it."""
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        return image

    def _run_pipeline(self, image, segment_lung):
        """Run the shared pipeline on a loaded image.

        Returns:
            ``(processed_image, lung_mask)``, where ``lung_mask`` is ``None``
            when *segment_lung* is false.

        Raises:
            ValueError: empty input, near-black/near-white input, or a
                near-uniform result after processing.
        """
        if image.size == 0:
            # checked first: mean() of an empty array is NaN and would slip past the next check
            raise ValueError("Input image is empty")
        mean_intensity = image.mean()
        if mean_intensity < 1 or mean_intensity > 253:
            raise ValueError(f"Image appears corrupted (mean intensity: {mean_intensity:.1f})")

        image = self.segmenter.remove_text_markers(image)

        lung_mask = None
        if segment_lung:
            lung_mask = self.segmenter.segment_lungs(image)
            mask_float = lung_mask.astype(float) / 255.0
            # attenuate (not zero) the non-lung region so some anatomical context survives
            image = (image * mask_float + image * 0.1 * (1 - mask_float)).astype(np.uint8)

        image = self.normalize_intensity(image)
        image = self.apply_clahe(image)

        pixel_range = float(image.max()) - float(image.min())
        if pixel_range < 2.0:
            raise ValueError("Preprocessing produced near-uniform image")
        return image, lung_mask

    def preprocess(self, image_path, segment_lung=True):
        """Load *image_path* as grayscale and return the processed uint8 image.

        Raises:
            ValueError: unreadable, corrupted, or degenerate image.
        """
        processed, _ = self._run_pipeline(self._load_grayscale(image_path), segment_lung)
        return processed

    def preprocess_array(self, image, segment_lung=True):
        """Preprocess an in-memory grayscale image array with optional lung segmentation.

        Raises:
            ValueError: *image* is None or empty, or fails the pipeline checks.
        """
        if image is None:
            raise ValueError("Input image is None")
        processed, _ = self._run_pipeline(image, segment_lung)
        return processed

    def preprocess_with_mask(self, image_path):
        """Preprocess image and return (processed_image, lung_mask_at_original_resolution).

        The mask is the 0/255 uint8 output of ``LungSegmenter.segment_lungs``
        computed on the marker-cleaned image, at the loaded image's resolution.
        """
        return self._run_pipeline(self._load_grayscale(image_path), segment_lung=True)
def get_train_transforms(image_size=224):
    """Build the training augmentation pipeline for single-channel X-rays.

    The pipeline resizes to an ``image_size`` square and applies light
    geometric and photometric augmentation. It then maps pixels to
    ``(x / 255 - 0.5) / 0.5`` (roughly [-1, 1]) and converts the result to a
    tensor. The ``Resize`` step is kept first because ``PreprocessedDataset``
    reads the resize height from the transform list.
    """
    square = (image_size, image_size)
    steps = [
        A.Resize(height=image_size, width=image_size),
        A.RandomResizedCrop(size=square, scale=(0.85, 1.0), p=0.5),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.7),
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
        # NOTE(review): `var_limit` may be deprecated in newer albumentations releases -- confirm pinned version
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=0.2),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def get_val_transforms(image_size=224):
    """Build the deterministic evaluation pipeline: resize, normalise to ~[-1, 1], to tensor.

    ``Resize`` stays first because ``PreprocessedDataset`` reads the resize
    height from the transform list.
    """
    steps = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=[0.5], std=[0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)
class PreprocessedDataset(torch.utils.data.Dataset):
    """Chest X-ray dataset with optional lung preprocessing, disk caching and RAM preloading.

    Args:
        image_paths: sequence of image paths (``str`` or ``Path``).
        labels: sequence of labels aligned with ``image_paths``.
        transforms: optional albumentations ``Compose`` applied as
            ``transforms(image=...)``. The first transform exposing ``height``
            sets the resize target for raw images.
        use_preprocessing: run ``LungPreprocessor`` instead of loading raw pixels.
        cache_pt: in lazy (non-RAM) preprocessing mode, store and reuse results
            in a ``<stem>.pt.cache`` file next to each image.
        load_to_ram: eagerly load every image into memory at construction.
    """

    def __init__(self, image_paths, labels, transforms=None, use_preprocessing=True, cache_pt=True, load_to_ram=False):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.use_preprocessing = use_preprocessing
        self.cache_pt = cache_pt
        self.preprocessor = LungPreprocessor() if use_preprocessing else None
        self.ram_cache = None
        if load_to_ram:
            self._preload_to_ram()

    @staticmethod
    def _cache_path(img_path):
        """Return the cache-file location for *img_path*; accepts ``str`` or ``Path``."""
        # NOTE: images differing only by extension (a.png / a.jpg) map to the same cache file.
        return Path(img_path).with_suffix('.pt.cache')

    def _target_size(self):
        """Square side used when resizing raw images or sizing zero-filled placeholders."""
        if self.transforms:
            for t in self.transforms:
                if hasattr(t, 'height'):
                    return t.height
        return self.preprocessor.image_size if self.preprocessor else 224

    @staticmethod
    def _load_raw(img_path, target_size):
        """Read *img_path* as grayscale and resize to a square; raise ValueError if unreadable."""
        image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Failed to load image: {img_path}")
        return cv2.resize(image, (target_size, target_size))

    @classmethod
    def _write_cache(cls, preprocessor, img_path):
        """Preprocess one image and save it to its cache file; exceptions propagate."""
        image = preprocessor.preprocess(str(img_path), segment_lung=True)
        torch.save(torch.from_numpy(image), cls._cache_path(img_path))

    def _get_preprocessed(self, img_path, target_size=224):
        """Return the image for *img_path*.

        Raw mode returns a resized raw image. Preprocessing mode returns the
        cached preprocessed image when allowed and present, otherwise it
        computes the image and caches it when allowed.
        """
        if not self.use_preprocessing:
            return self._load_raw(img_path, target_size)
        cache_path = self._cache_path(img_path)
        if self.cache_pt and cache_path.exists():
            return torch.load(cache_path, weights_only=True).numpy()
        image = self.preprocessor.preprocess(str(img_path), segment_lung=True)
        if self.cache_pt:
            torch.save(torch.from_numpy(image), cache_path)
        return image

    def _build_cache(self, paths):
        """Sequentially build cache files for *paths*; failures are reported and skipped."""
        processor = LungPreprocessor()
        for p in paths:
            try:
                self._write_cache(processor, p)
            except Exception as e:
                print(f" Warning: Failed to cache {p}: {e}")

    def _cache_missing(self):
        """Build cache files for every path that lacks one, using a thread pool."""
        uncached = [p for p in self.image_paths if not self._cache_path(p).exists()]
        if not uncached:
            return
        n_workers = min(8, len(uncached))
        print(f" Processing {len(uncached)} uncached images with {n_workers} workers...")
        # One shared instance: its methods keep no per-call state on self
        # (CLAHE objects are created per call), so threads can share it.
        worker = LungPreprocessor()
        failures = 0
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            futures = [pool.submit(self._write_cache, worker, p) for p in uncached]
            for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
                try:
                    future.result()
                except Exception:
                    failures += 1
        if failures:
            print(f" Warning: {failures} images could not be preprocessed and will be zero-filled")

    def _preload_to_ram(self):
        """Populate ``self.ram_cache`` with one uint8 array per path.

        In preprocessing mode, missing cache files are built first and then
        every cache file is loaded. In raw mode, images are read and resized
        the same way ``__getitem__`` does. An image that fails to load is
        replaced by a zero image (its label is kept) and a warning is printed.
        """
        n_total = len(self.image_paths)
        print(f" Pre-loading {n_total} images into RAM...")
        target_size = self._target_size()

        if self.use_preprocessing:
            # NOTE: this path always writes .pt.cache files, independent of self.cache_pt.
            self._cache_missing()

            def load(path):
                return torch.load(self._cache_path(path), weights_only=True).numpy()
        else:
            def load(path):
                return self._load_raw(path, target_size)

        self.ram_cache = []
        for path in tqdm(self.image_paths, desc="Loading RAM"):
            try:
                image = load(path)
            except Exception as e:
                print(f" Warning: Failed to load {path}: {e}")
                image = np.zeros((target_size, target_size), dtype=np.uint8)
            self.ram_cache.append(image)

        # Preprocessed images keep their original resolution, so sum the real
        # sizes rather than assuming target_size**2 bytes each.
        size_gb = sum(img.nbytes for img in self.ram_cache) / 1e9
        print(f" Cached {len(self.ram_cache)} images in RAM (~{size_gb:.1f} GB)")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for *idx*.

        Without transforms, the tensor is ``uint8 / 255`` as float with a
        leading channel axis. A 3-channel result is averaged to one channel,
        and a 2-D result gains a leading channel axis.
        """
        label = self.labels[idx]
        if self.ram_cache is not None:
            image = self.ram_cache[idx]
        elif self.use_preprocessing:
            image = self._get_preprocessed(self.image_paths[idx])
        else:
            image = self._load_raw(self.image_paths[idx], self._target_size())

        if self.transforms:
            image = self.transforms(image=image)['image']
        else:
            image = torch.from_numpy(image).unsqueeze(0).float() / 255.0

        if image.shape[0] == 3:
            image = image.mean(dim=0, keepdim=True)
        elif image.dim() == 2:
            image = image.unsqueeze(0)
        return image, label