Spaces:
Running
Running
| """ | |
| Preprocessing module for FetalCLIP. | |
| Supports two pipelines: | |
| 1. DICOM (Full): US region extraction, fan isolation, text removal, denoising | |
| 2. Image (Basic): Square padding (resizing is left to the model) | |
| """ | |
| import cv2 | |
| import copy | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Tuple, Dict, List, Optional | |
| from io import BytesIO | |
| # Try importing DICOM-specific libraries | |
| try: | |
| from pydicom import dcmread | |
| from pydicom.pixel_data_handlers import convert_color_space | |
| DICOM_AVAILABLE = True | |
| except ImportError: | |
| DICOM_AVAILABLE = False | |
| try: | |
| from skimage.restoration import denoise_nl_means, estimate_sigma | |
| SKIMAGE_AVAILABLE = True | |
| except ImportError: | |
| SKIMAGE_AVAILABLE = False | |
| try: | |
| import albumentations as A | |
| ALBUMENTATIONS_AVAILABLE = True | |
| except ImportError: | |
| ALBUMENTATIONS_AVAILABLE = False | |
| # ============================================================================ | |
| # CONSTANTS | |
| # ============================================================================ | |
# Output size and resampling filter.
# NOTE(review): neither TARGET_SIZE nor INTERPOLATION is referenced by the
# code visible in this module; padded images are returned without resizing.
TARGET_SIZE = (512, 512)
INTERPOLATION = cv2.INTER_LANCZOS4

# Grey-level threshold handed to get_fan_region() by preprocess_dicom().
INTENSITY_THRESHOLD = 0

# Extra rows above the picture-in-picture view that are blanked as well
# (passed as sv_mc_y to get_us_region_from_dcm()).
SMALL_VIEW_MARGIN_CROP_Y = 1

# Exact pixel value that remove_text_box() treats as text-box background,
# and the minimum bounding-rectangle area for that box to be blanked.
YELLOW_BOX_BACKGROUND_PIXEL = np.array([57, 57, 57])
MIN_YELLOW_BOX_RECT_AREA = 2000

# 9x9 structuring element used to grow the annotation mask before inpainting.
MASK_INPAINTING_DILATE_KERNEL = np.ones((9, 9), np.uint8)

# Keyword arguments forwarded to skimage's denoise_nl_means().
DENOISE_NL_MEANS_PATCH_KW = {
    "patch_size": 7,
    "patch_distance": 6,
    "channel_axis": -1,
}

# Neighbourhood radius passed to cv2.inpaint().
INPAINT_RADIUS = 5
| # ============================================================================ | |
| # TEXT DETECTION UTILITIES (from utils_husain.py) | |
| # ============================================================================ | |
def rgb2gray(rgb: np.ndarray) -> np.ndarray:
    """Return a copy of ``rgb`` whose first three channels all hold its luma.

    Luma is ``0.299*c0 + 0.587*c1 + 0.114*c2`` (channels 0/1/2). It is written
    back into a copy of the input, so it takes the input's dtype; integer
    inputs are truncated. Any channels beyond the third are copied unchanged.
    """
    luma = 0.299 * rgb[:, :, 0] + 0.5870 * rgb[:, :, 1] + 0.1140 * rgb[:, :, 2]
    out = rgb.copy()
    # Broadcast the single luma plane into channels 0..2 in one assignment.
    out[:, :, :3] = luma[:, :, np.newaxis]
    return out
def mask_filter(image: np.ndarray, grey_threshold: int) -> np.ndarray:
    """Binary-threshold ``image`` on its luma.

    Args:
        image: image with at least three channels; its luma is computed by
            rgb2gray().
        grey_threshold: pixels whose luma is strictly greater than this value
            (checked on each of the first three luma channels) are kept.

    Returns:
        uint8 array of shape (H, W, 3): 255 in every channel for kept pixels,
        0 elsewhere.
    """
    grey = rgb2gray(image)
    above = (
        (grey[:, :, 0] > grey_threshold)
        & (grey[:, :, 1] > grey_threshold)
        & (grey[:, :, 2] > grey_threshold)
    )
    mask = np.zeros((image.shape[0], image.shape[1], 3))
    mask[above] = 255
    return np.uint8(mask)
def maximize_contrast(img_grayscale: np.ndarray) -> np.ndarray:
    """Enhance local contrast of a single-channel image.

    Computes ``img + top_hat(img) - black_hat(img)`` with a 3x3 rectangular
    structuring element: bright detail (top-hat) is added and dark detail
    (black-hat) is subtracted. ``cv2.add``/``cv2.subtract`` are used instead
    of numpy ``+``/``-`` so integer results saturate rather than wrap.

    Args:
        img_grayscale: 2-D image.

    Returns:
        The contrast-enhanced image.

    Raises:
        ValueError: if ``img_grayscale`` is not 2-D.
    """
    # Callers pass single-channel images; fail fast with a clear message
    # instead of relying on an unused tuple-unpack of ``shape``.
    if img_grayscale.ndim != 2:
        raise ValueError(
            f"maximize_contrast expects a 2-D image, got shape {img_grayscale.shape}"
        )
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    top_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_TOPHAT, kernel)
    black_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_BLACKHAT, kernel)
    return cv2.subtract(cv2.add(img_grayscale, top_hat), black_hat)
def detect_white_annotation(img: np.ndarray) -> np.ndarray:
    """Return a single-channel mask of (near-)white text in ``img``.

    The image is converted to grey with COLOR_BGR2GRAY and contrast-enhanced
    with maximize_contrast(). Pixels above 254 are kept via mask_filter(), and
    the result is dilated with a 3x3 kernel to cover glyph edges.
    """
    enhanced = maximize_contrast(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
    near_white = mask_filter(cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR), 254)
    grown = cv2.dilate(near_white, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.cvtColor(grown, cv2.COLOR_BGR2GRAY)
def detect_cyan(img: np.ndarray) -> np.ndarray:
    """Mask pixels whose HSV lies in H 85-95, S 150-255, V 20-255.

    Returns a copy of the cv2.inRange() result.

    NOTE(review): unlike the other colour detectors in this module, this one
    converts with COLOR_BGR2HSV rather than COLOR_RGB2HSV. If callers pass
    RGB data, the hue band is evaluated on R/B-swapped colours. Confirm this
    is intended before changing it, because it alters which pixels are masked.
    """
    hsv_image = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    band_low = np.uint8([85, 150, 20])
    band_high = np.uint8([95, 255, 255])
    return np.array(cv2.inRange(hsv_image, band_low, band_high))
def detect_purple_text(img: np.ndarray) -> np.ndarray:
    """Mask pixels whose HSV (via COLOR_RGB2HSV) lies in H 110-130, S 100-255, V 50-255.

    Returns a copy of the cv2.inRange() result.
    """
    hsv_image = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    in_band = cv2.inRange(
        hsv_image,
        np.uint8([110, 100, 50]),
        np.uint8([130, 255, 255]),
    )
    return np.array(in_band)
def detect_orange_text(img: np.ndarray) -> np.ndarray:
    """Mask pixels whose HSV (via COLOR_RGB2HSV) lies in H 12-27, S 150-255, V 100-255.

    Returns a copy of the cv2.inRange() result.
    """
    hsv_image = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    in_band = cv2.inRange(
        hsv_image,
        np.uint8([12, 150, 100]),
        np.uint8([27, 255, 255]),
    )
    return np.array(in_band)
def detect_green_text(img: np.ndarray) -> np.ndarray:
    """Mask pixels whose HSV (via COLOR_RGB2HSV) lies in H 50-70, S 100-255, V 50-255.

    Returns a copy of the cv2.inRange() result.
    """
    hsv_image = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    in_band = cv2.inRange(
        hsv_image,
        np.uint8([50, 100, 50]),
        np.uint8([70, 255, 255]),
    )
    return np.array(in_band)
def detect_annotation(img: np.ndarray) -> np.ndarray:
    """Build a single-channel uint8 mask (255 = annotation) of burned-in text.

    The five colour detectors (white, cyan, orange, purple, green) are each
    thresholded at >= 127 and OR-ed together. The union is contrast-enhanced,
    blurred with a 5x5 Gaussian and Otsu-thresholded. The Otsu result is then
    OR-ed with the un-blurred enhanced union so no originally detected pixel
    is lost.
    """
    detectors = (
        detect_white_annotation,
        detect_cyan,
        detect_orange_text,
        detect_purple_text,
        detect_green_text,
    )
    union = np.logical_or.reduce([detector(img) >= 127 for detector in detectors])
    enhanced = maximize_contrast(union.astype(np.uint8) * 255)
    smoothed = cv2.GaussianBlur(enhanced, (5, 5), 0)
    _, otsu = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.bitwise_or(otsu, enhanced)
| # ============================================================================ | |
| # DICOM UTILITIES (from utils_adam.py) | |
| # ============================================================================ | |
def remove_text_box(im: np.ndarray, box_background_pixel: np.ndarray, min_rect_area: int = 2000) -> np.ndarray:
    """Blank the largest box whose pixels exactly equal ``box_background_pixel``.

    Pixels matching ``box_background_pixel`` on every channel are grouped into
    external contours. The bounding rectangle of the largest contour is set
    to 0 when its area is at least ``min_rect_area``.

    Note: ``im`` is modified in place and also returned; pass a copy if the
    original must be preserved.
    """
    matches = np.all(im == box_background_pixel, axis=-1)
    outlines, _ = cv2.findContours(
        matches.astype(np.uint8) * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )
    if not outlines:
        return im
    x, y, w, h = cv2.boundingRect(max(outlines, key=cv2.contourArea))
    if w * h >= min_rect_area:
        im[y:y + h, x:x + w] = 0
    return im
def pad_to_square(im: np.ndarray) -> np.ndarray:
    """Zero-pad ``im`` so that height == width == its longer side.

    When albumentations is importable, PadIfNeeded with a constant
    (border_mode=0) black border is used. Otherwise ``im`` is centred on a zero
    canvas of the same dtype, with any odd leftover pixel on the bottom/right.
    """
    side = max(im.shape[:2])
    if ALBUMENTATIONS_AVAILABLE:
        padder = A.PadIfNeeded(min_height=side, min_width=side,
                               border_mode=0, value=(0, 0, 0))
        return padder(image=im)["image"]

    rows, cols = im.shape[:2]
    channel_dims = (im.shape[2],) if len(im.shape) == 3 else ()
    canvas = np.zeros((side, side) + channel_dims, dtype=im.dtype)
    top = (side - rows) // 2
    left = (side - cols) // 2
    canvas[top:top + rows, left:left + cols] = im
    return canvas
def get_fan_region(im: np.ndarray, threshold: int = 1) -> np.ndarray:
    """Isolate the largest non-dark region (the ultrasound fan/cone).

    ``im`` is converted to grey (COLOR_BGR2GRAY) and binarised: pixels strictly
    greater than ``threshold`` are kept. The largest external contour is then
    found. The result is ``im`` cropped to that contour's bounding box, with
    pixels outside the filled contour set to 0.

    If no contour is found, ``im`` itself is returned unchanged.
    """
    grey = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(grey, threshold, 255, cv2.THRESH_BINARY)
    outlines, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(outlines) == 0:
        return im

    fan_outline = max(outlines, key=cv2.contourArea)
    fill = np.zeros_like(im)
    cv2.drawContours(fill, [fan_outline], -1, (255, 255, 255), thickness=cv2.FILLED)

    x, y, w, h = cv2.boundingRect(fan_outline)
    return cv2.bitwise_and(im[y:y + h, x:x + w], fill[y:y + h, x:x + w])
def get_us_region_from_dcm(us, sv_mc_y: int = 1) -> np.ndarray:
    """Crop the ultrasound image region out of a parsed DICOM dataset.

    Crop box, from ``SequenceOfUltrasoundRegions`` when present:

    * exactly two regions, both with ``RegionDataType == 1``: treated as a
      main view plus a small (picture-in-picture) view. Rows
      ``[max(y0_f - sv_mc_y, 0), y1_f)`` and columns ``[x0_f, x1_f)`` are
      zeroed, where ``x0_f``/``x1_f`` are the min/max of the two regions'
      ``RegionLocationMinX0``, ``y0_f`` is the max ``RegionLocationMinY0`` and
      ``y1_f`` the max ``RegionLocationMaxY1``. The crop is the union of both
      regions' boxes;
    * otherwise, if the first region has ``RegionDataType == 1``: crop to
      that region's box;
    * otherwise, or with no region metadata: keep the whole frame.

    Pixel data are normalised before cropping:

    * multi-frame data: only the first frame is kept. This applies to 4-D
      arrays, and to 3-D arrays when ``SamplesPerPixel == 1``;
    * ``YBR_FULL*`` photometric interpretations are converted to RGB;
    * single-channel frames are replicated into 3 identical channels, because
      the crop below indexes a channel axis.

    Args:
        us: parsed DICOM dataset (anything exposing ``pixel_array`` and the
            attributes read here).
        sv_mc_y: extra rows above the small view that are zeroed as well.

    Returns:
        A new ``(rows, cols, channels)`` array; ``us.pixel_array`` is not
        modified.
    """
    pixels = us.pixel_array
    # Multi-frame (cine) data: keep the first frame, as the DICOM preview does.
    # A 3-D array is ambiguous (one colour frame vs. several grey frames), so
    # it is only treated as multi-frame when one sample per pixel is declared.
    if pixels.ndim == 4 or (pixels.ndim == 3 and int(getattr(us, 'SamplesPerPixel', 3)) == 1):
        pixels = pixels[0]
    frame_rows, frame_cols = pixels.shape[:2]

    # Small-view box to blank (None when there is no small view).
    x0_f, x1_f, y0_f, y1_f = None, None, None, None
    # Crop box; defaults to the whole frame.
    x0, x1, y0, y1 = 0, frame_cols, 0, frame_rows

    if hasattr(us, 'SequenceOfUltrasoundRegions') and len(us.SequenceOfUltrasoundRegions) > 0:
        regions = us.SequenceOfUltrasoundRegions
        if len(regions) == 2 and int(regions[0].RegionDataType) == 1 and int(regions[1].RegionDataType) == 1:
            # Image with a small view (picture-in-picture).
            x0_f = np.min([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            x1_f = np.max([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            y0_f = np.max([regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0])
            y1_f = np.max([regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1])
            x0 = min(regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0)
            x1 = max(regions[0].RegionLocationMaxX1, regions[1].RegionLocationMaxX1)
            y0 = min(regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0)
            y1 = max(regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1)
        elif int(regions[0].RegionDataType) == 1:
            x0 = regions[0].RegionLocationMinX0
            x1 = regions[0].RegionLocationMaxX1
            y0 = regions[0].RegionLocationMinY0
            y1 = regions[0].RegionLocationMaxY1

    ds = copy.deepcopy(pixels)

    # NOTE(review): recent pydicom releases may already return RGB from
    # pixel_array for YBR_FULL* data; if so this converts twice. Confirm
    # against the pinned pydicom version.
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)

    # Monochrome frame: give it a channel axis so the slicing below works.
    if ds.ndim == 2:
        ds = np.stack([ds, ds, ds], axis=-1)

    if x0_f is not None:
        # Clamp at 0: a negative start index would wrap around to the bottom
        # of the image and blank the wrong rows.
        top = max(int(y0_f) - sv_mc_y, 0)
        ds[top:y1_f, x0_f:x1_f, :] = 0

    return ds[y0:y1, x0:x1, :]
| # ============================================================================ | |
| # MAIN PREPROCESSING FUNCTIONS | |
| # ============================================================================ | |
def preprocess_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict]:
    """
    Full DICOM preprocessing pipeline.

    Steps:
        1. Parse the DICOM bytes with pydicom.
        2. Crop the ultrasound region using the DICOM region metadata
           (get_us_region_from_dcm).
        3. Blank the largest box whose pixels equal YELLOW_BOX_BACKGROUND_PIXEL
           (remove_text_box).
        4. Keep only the largest non-black region, i.e. the fan/cone
           (get_fan_region).
        5. Detect burned-in annotations and dilate that mask with
           MASK_INPAINTING_DILATE_KERNEL.
        6. Inpaint the masked pixels (cv2.INPAINT_TELEA).
        7. Scale by the image maximum, denoise with non-local means when
           scikit-image is available, and convert back to uint8.
        8. Zero-pad to a square.

    No resizing is done here; the padded image is returned at its padded size
    and resizing is left to the model.

    Args:
        file_bytes: raw bytes of a DICOM file.

    Returns:
        Tuple of (RGB PIL image, info dict). The info dict has keys ``type``,
        ``pipeline``, ``steps_applied`` and ``metadata``.

    Raises:
        RuntimeError: if pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed. Install with: pip install pydicom")

    # Parse DICOM
    us = dcmread(BytesIO(file_bytes))

    # Extract ultrasound region
    ds = get_us_region_from_dcm(us, sv_mc_y=SMALL_VIEW_MARGIN_CROP_Y)

    # remove_text_box writes into its input, so hand it a copy.
    img = remove_text_box(ds.copy(), box_background_pixel=YELLOW_BOX_BACKGROUND_PIXEL,
                          min_rect_area=MIN_YELLOW_BOX_RECT_AREA)

    # Extract fan region
    fan = get_fan_region(img, threshold=INTENSITY_THRESHOLD)

    # Detect annotations and grow the mask so glyph edges are covered too.
    mask_inpaint = detect_annotation(fan)
    mask_inpaint = cv2.dilate(mask_inpaint, MASK_INPAINTING_DILATE_KERNEL)

    # cv2.inpaint returns a new image and leaves `fan` untouched, so no
    # defensive copy is needed.
    dst = cv2.inpaint(fan, mask_inpaint, INPAINT_RADIUS, cv2.INPAINT_TELEA)

    # Scale so the brightest pixel becomes 1.0 (a contrast stretch, not a
    # fixed /255). An all-black image is left as-is to avoid dividing by zero.
    peak = np.max(dst)
    dst = dst / peak if peak > 0 else dst

    # Denoise
    if SKIMAGE_AVAILABLE:
        sigma = estimate_sigma(dst, channel_axis=-1, average_sigmas=True)
        denoised = denoise_nl_means(dst, h=0.8 * sigma, fast_mode=True, **DENOISE_NL_MEANS_PATCH_KW)
        median = np.clip(denoised * 255, 0, 255).astype(np.uint8)
    else:
        median = np.clip(dst * 255, 0, 255).astype(np.uint8)

    # Pad to square (model will resize to 224×224)
    img = pad_to_square(median)
    padded_size = max(img.shape[:2])

    # Extract metadata. This is best-effort: any problem reading the header
    # falls back to the fan dimensions and unit spacing.
    try:
        rows = getattr(us, 'Rows', fan.shape[0])
        columns = getattr(us, 'Columns', fan.shape[1])
        if hasattr(us, 'PixelSpacing') and us.PixelSpacing is not None:
            orig_pixel_spacing = [float(sp) for sp in us.PixelSpacing]
        else:
            orig_pixel_spacing = [1.0, 1.0]
    except Exception:
        rows = fan.shape[0]
        columns = fan.shape[1]
        orig_pixel_spacing = [1.0, 1.0]

    metadata = {
        'original_size': (rows, columns),
        'original_pixel_spacing': orig_pixel_spacing,
        'fan_size': (fan.shape[0], fan.shape[1]),
        'padded_size': padded_size,
        'pixel_spacing': orig_pixel_spacing[0],  # Original spacing, model.py adjusts for resize
        'processed_size': (padded_size, padded_size),
    }

    # Convert to PIL
    if img.ndim == 2:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        # NOTE(review): the pixel data was already converted to RGB upstream,
        # so this swaps the R and B channels. It is kept as-is, presumably to
        # match the preprocessing the model was trained with; confirm before
        # changing it.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img_rgb)

    steps_applied = [
        "dicom_parsing",
        "us_region_extraction",
        "text_box_removal",
        "fan_extraction",
        "annotation_detection",
        "inpainting",
        "denoising" if SKIMAGE_AVAILABLE else "normalization",
        "square_padding",
    ]

    return img_pil, {
        "type": "dicom",
        "pipeline": "full",
        "steps_applied": steps_applied,
        "metadata": metadata
    }
def preprocess_image(image: Image.Image) -> Tuple[Image.Image, Dict]:
    """
    Basic (non-DICOM) preprocessing: centre the image on a black square.

    Steps:
        1. Convert any mode other than RGB or L to RGB.
        2. Paste the image, centred, onto a black square canvas whose side
           is the image's longer edge.
        3. Convert an L (greyscale) result to RGB.

    Resizing is not done here (the model resizes to 224).

    Returns:
        Tuple of (RGB PIL image, info dict). ``metadata.original_size`` is
        (height, width) of the input; ``metadata.processed_size`` is
        (side, side).
    """
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')

    width, height = image.size
    side = max(width, height)
    offset = ((side - width) // 2, (side - height) // 2)

    fill = 0 if image.mode != "RGB" else (0, 0, 0)
    squared = Image.new(image.mode, (side, side), fill)
    squared.paste(image, offset)

    if squared.mode == 'L':
        squared = squared.convert('RGB')

    info = {
        "type": "image",
        "pipeline": "basic",
        "steps_applied": ["rgb_conversion", "square_padding"],
        "metadata": {
            "original_size": (height, width),
            "processed_size": (side, side),
        },
    }
    return squared, info
def is_dicom_file(file_bytes: bytes, filename: Optional[str]) -> bool:
    """Return True if the upload looks like a DICOM file.

    A file is accepted if either check passes:

    * ``filename`` ends in ``.dcm`` or ``.dicom`` (case-insensitive). A
      missing or empty ``filename`` just skips this check.
    * the DICOM Part 10 signature: the 4 bytes ``b"DICM"`` at offset 128,
      right after the 128-byte preamble.
    """
    if filename and filename.lower().endswith(('.dcm', '.dicom')):
        return True
    # Slicing never raises on short input: a buffer shorter than 132 bytes
    # simply yields fewer than 4 bytes and fails the comparison. (A length
    # guard of `> 132` would wrongly reject a buffer of exactly 132 bytes.)
    return file_bytes[128:132] == b'DICM'
def image_to_base64(image: Image.Image) -> str:
    """Encode ``image`` as PNG and return the bytes as base64 text."""
    import base64

    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('utf-8')
def get_dicom_preview(file_bytes: bytes) -> str:
    """Return a base64 PNG of the raw DICOM image for display (no preprocessing).

    Only what is needed to make the pixels displayable is applied:

    * ``YBR_FULL*`` photometric interpretations are converted to RGB;
    * only the first frame of multi-frame data is kept;
    * non-uint8 data is brought into uint8: min-max scaled to 0-255 when
      values fall outside that range (or are floating point), otherwise
      cast directly;
    * single-channel data is expanded to 3 channels.

    Raises:
        RuntimeError: if pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed")

    us = dcmread(BytesIO(file_bytes))
    ds = us.pixel_array

    # Handle color space
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)

    # Handle video (take first frame). A 4-D array is
    # (frames, rows, cols, samples); a 3-D array is only multi-frame when the
    # dataset declares one sample per pixel.
    if ds.ndim == 4 or (ds.ndim == 3 and int(getattr(us, 'SamplesPerPixel', 3)) == 1):
        ds = ds[0]

    # PIL / cv2 below need uint8. Scale when the data does not already fit
    # 0-255; guard against a constant image (zero range) to avoid NaNs.
    if ds.dtype != np.uint8:
        lo, hi = float(ds.min()), float(ds.max())
        if hi > 255 or lo < 0 or np.issubdtype(ds.dtype, np.floating):
            span = hi - lo
            ds = (ds - lo) / span * 255 if span > 0 else np.zeros(ds.shape)
        ds = np.clip(ds, 0, 255).astype(np.uint8)

    # Convert to RGB if grayscale
    if ds.ndim == 2:
        ds = cv2.cvtColor(ds, cv2.COLOR_GRAY2RGB)
    elif ds.shape[2] == 3:
        # NOTE(review): pixels are already RGB here, so this swaps R and B.
        # It matches the final conversion in preprocess_dicom; confirm it is
        # intended before changing either.
        ds = cv2.cvtColor(ds, cv2.COLOR_BGR2RGB)

    img_pil = Image.fromarray(ds)
    return image_to_base64(img_pil)
def preprocess_file(file_bytes: bytes, filename: str) -> Tuple[Image.Image, Dict]:
    """
    Detect the upload type and run the matching preprocessing pipeline.

    DICOM uploads (per is_dicom_file) go through preprocess_dicom. Anything
    else is opened with PIL and passed to preprocess_image. In both cases the
    returned info dict gains a ``processed_image_base64`` entry (PNG) for the
    frontend.

    Returns:
        Tuple of (PIL Image, preprocessing info dict with base64 image)
    """
    if is_dicom_file(file_bytes, filename):
        processed_image, info = preprocess_dicom(file_bytes)
    else:
        processed_image, info = preprocess_image(Image.open(BytesIO(file_bytes)))

    info["processed_image_base64"] = image_to_base64(processed_image)
    return processed_image, info