""" Preprocessing module for FetalCLIP. Supports two pipelines: 1. DICOM (Full): US region extraction, fan isolation, text removal, denoising 2. Image (Basic): Square padding, resize """ import cv2 import copy import numpy as np from PIL import Image from typing import Tuple, Dict, List, Optional from io import BytesIO # Try importing DICOM-specific libraries try: from pydicom import dcmread from pydicom.pixel_data_handlers import convert_color_space DICOM_AVAILABLE = True except ImportError: DICOM_AVAILABLE = False try: from skimage.restoration import denoise_nl_means, estimate_sigma SKIMAGE_AVAILABLE = True except ImportError: SKIMAGE_AVAILABLE = False try: import albumentations as A ALBUMENTATIONS_AVAILABLE = True except ImportError: ALBUMENTATIONS_AVAILABLE = False # ============================================================================ # CONSTANTS # ============================================================================ TARGET_SIZE = (512, 512) INTERPOLATION = cv2.INTER_LANCZOS4 INTENSITY_THRESHOLD = 0 SMALL_VIEW_MARGIN_CROP_Y = 1 YELLOW_BOX_BACKGROUND_PIXEL = np.array([57, 57, 57]) MIN_YELLOW_BOX_RECT_AREA = 2_000 MASK_INPAINTING_DILATE_KERNEL = np.ones((9, 9), np.uint8) DENOISE_NL_MEANS_PATCH_KW = dict( patch_size=7, patch_distance=6, channel_axis=-1, ) INPAINT_RADIUS = 5 # ============================================================================ # TEXT DETECTION UTILITIES (from utils_husain.py) # ============================================================================ def rgb2gray(rgb: np.ndarray) -> np.ndarray: """Convert RGB to grayscale while keeping 3 channels.""" r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2] gray = 0.299 * r + 0.5870 * g + 0.1140 * b rgb_grey = rgb.copy() rgb_grey[:, :, 0] = gray rgb_grey[:, :, 1] = gray rgb_grey[:, :, 2] = gray return rgb_grey def mask_filter(image: np.ndarray, grey_threshold: int) -> np.ndarray: """Create binary mask for pixels above threshold.""" img = image.copy() grey_img = rgb2gray(img) convert = np.zeros((img.shape[0], img.shape[1], 3)) idxs = np.where( (grey_img[:, :, 0] > grey_threshold) & (grey_img[:, :, 1] > grey_threshold) & (grey_img[:, :, 2] > grey_threshold) ) convert[idxs] = [255, 255, 255] return np.uint8(convert) def maximize_contrast(img_grayscale: np.ndarray) -> np.ndarray: """Enhance contrast using morphological operations.""" height, width = img_grayscale.shape structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) img_top_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_TOPHAT, structuring_element) img_black_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_BLACKHAT, structuring_element) img_plus_top_hat = cv2.add(img_grayscale, img_top_hat) result = cv2.subtract(img_plus_top_hat, img_black_hat) return result def detect_white_annotation(img: np.ndarray) -> np.ndarray: """Detect white text/annotations.""" img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) dif = maximize_contrast(img_gray) dif_rgb = cv2.cvtColor(dif, cv2.COLOR_GRAY2BGR) masked_img = mask_filter(dif_rgb, 254) dilation = cv2.dilate(masked_img, np.ones((3, 3), np.uint8), iterations=1) mask = cv2.cvtColor(dilation, cv2.COLOR_BGR2GRAY) return mask def detect_cyan(img: np.ndarray) -> np.ndarray: """Detect cyan colored text.""" image_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) lowers = np.uint8([85, 150, 20]) uppers = np.uint8([95, 255, 255]) mask = np.array(cv2.inRange(image_hsv, lowers, uppers)) return mask def detect_purple_text(img: np.ndarray) -> np.ndarray: """Detect purple colored text.""" image_hsv = cv2.cvtColor(img, 
    lowers = np.uint8([110, 100, 50])
    uppers = np.uint8([130, 255, 255])
    mask = np.array(cv2.inRange(image_hsv, lowers, uppers))
    return mask


def detect_orange_text(img: np.ndarray) -> np.ndarray:
    """Detect orange colored text."""
    image_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    lowers = np.uint8([12, 150, 100])
    uppers = np.uint8([27, 255, 255])
    mask = np.array(cv2.inRange(image_hsv, lowers, uppers))
    return mask


def detect_green_text(img: np.ndarray) -> np.ndarray:
    """Detect green colored text."""
    image_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    lowers = np.uint8([50, 100, 50])
    uppers = np.uint8([70, 255, 255])
    mask = np.array(cv2.inRange(image_hsv, lowers, uppers))
    return mask


def detect_annotation(img: np.ndarray) -> np.ndarray:
    """Detect all text annotations (white, cyan, orange, purple, green)."""
    d1 = (detect_white_annotation(img) >= 127).astype(np.float32)
    d2 = (detect_cyan(img) >= 127).astype(np.float32)
    d3 = (detect_orange_text(img) >= 127).astype(np.float32)
    d4 = (detect_purple_text(img) >= 127).astype(np.float32)
    d5 = (detect_green_text(img) >= 127).astype(np.float32)

    inpaint_mask = d1 + d2 + d3 + d4 + d5
    inpaint_mask = (inpaint_mask > 0).astype(np.uint8) * 255
    inpaint_mask = maximize_contrast(inpaint_mask)

    blur = cv2.GaussianBlur(inpaint_mask, (5, 5), 0)
    ret3, th3 = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    th3 = cv2.bitwise_or(th3, inpaint_mask)
    return th3


# ============================================================================
# DICOM UTILITIES (from utils_adam.py)
# ============================================================================


def remove_text_box(im: np.ndarray,
                    box_background_pixel: np.ndarray,
                    min_rect_area: int = 2000) -> np.ndarray:
    """Remove yellow/gray text boxes from the image."""
    binary = np.all(im == box_background_pixel, axis=-1).astype(np.uint8)
    binary = binary * 255
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        return im
    contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(contour)
    if w * h >= min_rect_area:
        im[y:y + h, x:x + w] = 0
    return im


def pad_to_square(im: np.ndarray) -> np.ndarray:
    """Pad image to square using black padding."""
    if ALBUMENTATIONS_AVAILABLE:
        target_size = max(im.shape[:2])
        return A.PadIfNeeded(min_height=target_size, min_width=target_size,
                             border_mode=0, value=(0, 0, 0))(image=im)["image"]
    else:
        # Fallback without albumentations
        height, width = im.shape[:2]
        max_side = max(height, width)
        if len(im.shape) == 3:
            result = np.zeros((max_side, max_side, im.shape[2]), dtype=im.dtype)
        else:
            result = np.zeros((max_side, max_side), dtype=im.dtype)
        y_offset = (max_side - height) // 2
        x_offset = (max_side - width) // 2
        result[y_offset:y_offset + height, x_offset:x_offset + width] = im
        return result


def get_fan_region(im: np.ndarray, threshold: int = 1) -> np.ndarray:
    """Extract the ultrasound fan/cone region."""
    imgray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    ret, thresh = cv2.threshold(imgray, threshold, 255, cv2.THRESH_BINARY)
    contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(contours) == 0:
        return im
    contour = max(contours, key=cv2.contourArea)

    # Create mask
    filled_image = np.zeros_like(im)
    cv2.drawContours(filled_image, [contour], -1, (255, 255, 255), thickness=cv2.FILLED)

    # Crop to bounding box
    x, y, w, h = cv2.boundingRect(contour)
    cropped_image = im[y:y + h, x:x + w]
    filled_image = filled_image[y:y + h, x:x + w]

    # Apply mask
    masked_image = cv2.bitwise_and(cropped_image, filled_image)
    return masked_image


def get_us_region_from_dcm(us, sv_mc_y: int = 1) -> np.ndarray:
    """Extract the ultrasound region from a DICOM dataset using its metadata."""
    # Initialize default coordinates
    x0_f, x1_f, y0_f, y1_f = None, None, None, None
    x0, x1, y0, y1 = 0, us.pixel_array.shape[1], 0, us.pixel_array.shape[0]

    # Check for ultrasound regions metadata
    if hasattr(us, 'SequenceOfUltrasoundRegions') and len(us.SequenceOfUltrasoundRegions) > 0:
        regions = us.SequenceOfUltrasoundRegions
        if (len(regions) == 2
                and int(regions[0].RegionDataType) == 1
                and int(regions[1].RegionDataType) == 1):
            # Image with small view (picture-in-picture)
            x0_f = np.min([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            x1_f = np.max([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            y0_f = np.max([regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0])
            y1_f = np.max([regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1])

            x0 = min(regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0)
            x1 = max(regions[0].RegionLocationMaxX1, regions[1].RegionLocationMaxX1)
            y0 = min(regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0)
            y1 = max(regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1)
        elif len(regions) >= 1 and int(regions[0].RegionDataType) == 1:
            x0 = regions[0].RegionLocationMinX0
            x1 = regions[0].RegionLocationMaxX1
            y0 = regions[0].RegionLocationMinY0
            y1 = regions[0].RegionLocationMaxY1

    ds = copy.deepcopy(us.pixel_array)

    # Handle color space conversion
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)

    # Remove small view if present
    if x0_f is not None:
        ds[y0_f - sv_mc_y:y1_f, x0_f:x1_f, :] = 0

    # Crop to ultrasound region
    ds = ds[y0:y1, x0:x1, :]
    return ds


# ============================================================================
# MAIN PREPROCESSING FUNCTIONS
# ============================================================================


def preprocess_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict]:
    """
    Full DICOM preprocessing pipeline.

    Steps:
    1. Parse DICOM file
    2. Extract ultrasound region from metadata
    3. Remove yellow text boxes
    4. Extract fan/cone region
    5. Detect text annotations
    6. Inpaint to remove text
    7. Denoise using non-local means
    8. Pad to square
    9. (Resizing to the model input size is handled downstream by the model transform)

    Returns:
        Tuple of (PIL Image, metadata dict)
    """
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed. Install with: pip install pydicom")

    # Parse DICOM
    us = dcmread(BytesIO(file_bytes))

    # Extract ultrasound region
    ds = get_us_region_from_dcm(us, sv_mc_y=SMALL_VIEW_MARGIN_CROP_Y)

    # Remove text box
    img = remove_text_box(ds.copy(),
                          box_background_pixel=YELLOW_BOX_BACKGROUND_PIXEL,
                          min_rect_area=MIN_YELLOW_BOX_RECT_AREA)

    # Extract fan region
    fan = get_fan_region(img, threshold=INTENSITY_THRESHOLD)

    # Detect annotations
    image_grey = fan.copy()
    mask_inpaint = detect_annotation(fan)
    mask_inpaint = cv2.dilate(mask_inpaint, MASK_INPAINTING_DILATE_KERNEL)

    # Inpaint to remove text
    dst = cv2.inpaint(image_grey, mask_inpaint, INPAINT_RADIUS, cv2.INPAINT_TELEA)
    dst = dst / np.max(dst) if np.max(dst) > 0 else dst

    # Denoise
    if SKIMAGE_AVAILABLE:
        sigma = estimate_sigma(dst, channel_axis=-1, average_sigmas=True)
        median = denoise_nl_means(dst, h=0.8 * sigma, fast_mode=True,
                                  **DENOISE_NL_MEANS_PATCH_KW)
        median = np.clip(median * 255, 0, 255).astype(np.uint8)
    else:
        median = np.clip(dst * 255, 0, 255).astype(np.uint8)

    # Pad to square (model will resize to 224×224)
    img = pad_to_square(median)
    padded_size = max(img.shape[:2])

    # Extract metadata
    try:
        rows = getattr(us, 'Rows', fan.shape[0])
        columns = getattr(us, 'Columns', fan.shape[1])
        if hasattr(us, 'PixelSpacing') and us.PixelSpacing is not None:
            orig_pixel_spacing = [float(sp) for sp in us.PixelSpacing]
        else:
            orig_pixel_spacing = [1.0, 1.0]
    except Exception:
        rows = fan.shape[0]
        columns = fan.shape[1]
        orig_pixel_spacing = [1.0, 1.0]

    metadata = {
        'original_size': (rows, columns),
        'original_pixel_spacing': orig_pixel_spacing,
        'fan_size': (fan.shape[0], fan.shape[1]),
        'padded_size': padded_size,
        'pixel_spacing': orig_pixel_spacing[0],  # Original spacing; model.py adjusts for resize
        'processed_size': (padded_size, padded_size),
    }

    # Convert to PIL
    if len(img.shape) == 2:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img_rgb)

    steps_applied = [
        "dicom_parsing",
        "us_region_extraction",
        "text_box_removal",
        "fan_extraction",
        "annotation_detection",
        "inpainting",
        "denoising" if SKIMAGE_AVAILABLE else "normalization",
        "square_padding",
    ]

    return img_pil, {
        "type": "dicom",
        "pipeline": "full",
        "steps_applied": steps_applied,
        "metadata": metadata,
    }


def preprocess_image(image: Image.Image) -> Tuple[Image.Image, Dict]:
    """
    Basic image preprocessing pipeline.

    Steps:
    1. Convert to RGB if needed
    2. Pad to square
    3. (Model will resize to 224)

    Returns:
        Tuple of (PIL Image, preprocessing info dict)
    """
    # Convert to RGB
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')

    width, height = image.size
    max_side = max(width, height)

    # Create square image with black padding
    padding_color = (0, 0, 0) if image.mode == "RGB" else 0
    new_image = Image.new(image.mode, (max_side, max_side), padding_color)

    # Center the original
    padding_left = (max_side - width) // 2
    padding_top = (max_side - height) // 2
    new_image.paste(image, (padding_left, padding_top))

    # Ensure RGB
    if new_image.mode == 'L':
        new_image = new_image.convert('RGB')

    steps_applied = [
        "rgb_conversion",
        "square_padding",
    ]

    return new_image, {
        "type": "image",
        "pipeline": "basic",
        "steps_applied": steps_applied,
        "metadata": {
            "original_size": (height, width),
            "processed_size": (max_side, max_side),
        },
    }


def is_dicom_file(file_bytes: bytes, filename: str) -> bool:
    """Check whether the file is a DICOM file."""
    # Check by extension
    lower_name = filename.lower()
    if lower_name.endswith('.dcm') or lower_name.endswith('.dicom'):
        return True
    # Check DICOM magic number (DICM at offset 128)
    if len(file_bytes) >= 132:
        if file_bytes[128:132] == b'DICM':
            return True
    return False


def image_to_base64(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def get_dicom_preview(file_bytes: bytes) -> str:
    """Extract the raw image from a DICOM file for preview (no preprocessing)."""
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed")

    us = dcmread(BytesIO(file_bytes))
    ds = us.pixel_array

    # Handle color space
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)

    # Handle video (take first frame)
    if len(ds.shape) == 4:
        ds = ds[0]

    # Normalize to 0-255
    if ds.max() > 255:
        ds = ((ds - ds.min()) / (ds.max() - ds.min()) * 255).astype(np.uint8)
    elif ds.dtype != np.uint8:
        # Ensure 8-bit pixels so the cv2/PIL conversions below behave as expected
        ds = ds.astype(np.uint8)

    # Convert to RGB if grayscale
    if len(ds.shape) == 2:
        ds = cv2.cvtColor(ds, cv2.COLOR_GRAY2RGB)
    elif ds.shape[2] == 3:
        ds = cv2.cvtColor(ds, cv2.COLOR_BGR2RGB)

    img_pil = Image.fromarray(ds)
    return image_to_base64(img_pil)


def preprocess_file(file_bytes: bytes, filename: str) -> Tuple[Image.Image, Dict]:
    """
    Automatically detect the file type and apply the appropriate preprocessing.

    Returns:
        Tuple of (PIL Image, preprocessing info dict with base64 image)
    """
    if is_dicom_file(file_bytes, filename):
        processed_image, info = preprocess_dicom(file_bytes)
        # Add base64 encoded image for frontend display
        info["processed_image_base64"] = image_to_base64(processed_image)
        return processed_image, info
    else:
        # Regular image
        image = Image.open(BytesIO(file_bytes))
        processed_image, info = preprocess_image(image)
        info["processed_image_base64"] = image_to_base64(processed_image)
        return processed_image, info
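

# ============================================================================
# USAGE EXAMPLE
# ============================================================================
# Minimal sketch showing how the preprocess_file entry point is intended to be
# called. This block is illustrative only (an assumption, not part of the
# original module): it assumes the module is run as a script and that a DICOM
# or image file path is supplied on the command line.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print(f"usage: python {sys.argv[0]} <path-to-dicom-or-image>")
        sys.exit(1)

    path = sys.argv[1]
    with open(path, "rb") as f:
        raw = f.read()

    # preprocess_file dispatches to the full DICOM pipeline or the basic image
    # pipeline based on the file extension / DICM magic number.
    processed, info = preprocess_file(raw, path)
    print(f"pipeline: {info['pipeline']}")
    print(f"steps:    {', '.join(info['steps_applied'])}")
    print(f"output:   {processed.size} ({processed.mode})")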