# FetalCLIP — backend/app/services/preprocessing.py
# Author: Numan Saeed
# View-aware GA with WHO biometry formulas
# Commit: cbd23a5
"""
Preprocessing module for FetalCLIP.
Supports two pipelines:
1. DICOM (Full): US region extraction, fan isolation, text removal, denoising
2. Image (Basic): Square padding, resize
"""
import base64
import copy
from io import BytesIO
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
from PIL import Image
# Try importing DICOM-specific libraries
try:
from pydicom import dcmread
from pydicom.pixel_data_handlers import convert_color_space
DICOM_AVAILABLE = True
except ImportError:
DICOM_AVAILABLE = False
try:
from skimage.restoration import denoise_nl_means, estimate_sigma
SKIMAGE_AVAILABLE = True
except ImportError:
SKIMAGE_AVAILABLE = False
try:
import albumentations as A
ALBUMENTATIONS_AVAILABLE = True
except ImportError:
ALBUMENTATIONS_AVAILABLE = False
# ============================================================================
# CONSTANTS
# ============================================================================
# Nominal output resolution for preprocessed frames (width, height).
TARGET_SIZE = (512, 512)
# Resampling filter used for any resizing (high-quality Lanczos).
INTERPOLATION = cv2.INTER_LANCZOS4
# Grey threshold used when extracting the fan region (any non-black pixel).
INTENSITY_THRESHOLD = 0
# Extra rows blanked above the small picture-in-picture view when removing it.
SMALL_VIEW_MARGIN_CROP_Y = 1
# Flat background colour of the burned-in yellow text box (BGR/RGB triplet).
YELLOW_BOX_BACKGROUND_PIXEL = np.array([57, 57, 57])
# Minimum bounding-rect area (px) for a region to be treated as a text box.
MIN_YELLOW_BOX_RECT_AREA = 2_000
# Kernel used to grow the annotation mask before inpainting.
MASK_INPAINTING_DILATE_KERNEL = np.ones((9, 9), np.uint8)
# Keyword arguments for skimage's non-local-means denoiser.
DENOISE_NL_MEANS_PATCH_KW = dict(
    patch_size=7,
    patch_distance=6,
    channel_axis=-1,
)
# Neighbourhood radius passed to cv2.inpaint.
INPAINT_RADIUS = 5
# ============================================================================
# TEXT DETECTION UTILITIES (from utils_husain.py)
# ============================================================================
def rgb2gray(rgb: np.ndarray) -> np.ndarray:
    """Return a 3-channel image whose channels all hold the luma of *rgb*.

    Uses the ITU-R BT.601 weights (0.299, 0.587, 0.114); the float luma is
    truncated on assignment into the copy's original dtype.
    """
    luma = (
        0.299 * rgb[:, :, 0]
        + 0.587 * rgb[:, :, 1]
        + 0.114 * rgb[:, :, 2]
    )
    out = rgb.copy()
    for channel in range(3):
        out[:, :, channel] = luma
    return out
def mask_filter(image: np.ndarray, grey_threshold: int) -> np.ndarray:
    """Return a uint8 (H, W, 3) mask: 255 where grey intensity exceeds threshold.

    The image is converted to (3-channel) grey first, so the threshold is
    applied to luma rather than to raw colour channels.
    """
    grey = rgb2gray(image.copy())
    mask = np.zeros((image.shape[0], image.shape[1], 3))
    bright = np.where(
        (grey[:, :, 0] > grey_threshold)
        & (grey[:, :, 1] > grey_threshold)
        & (grey[:, :, 2] > grey_threshold)
    )
    mask[bright] = [255, 255, 255]
    return np.uint8(mask)
def maximize_contrast(img_grayscale: np.ndarray) -> np.ndarray:
    """Boost local contrast of a single-channel image via morphology.

    Adds the top-hat (small bright detail) and subtracts the black-hat
    (small dark detail) computed with a 3x3 rectangular kernel, which
    sharpens thin structures such as overlaid text.

    Args:
        img_grayscale: Single-channel uint8 image.

    Returns:
        Contrast-enhanced image of the same shape/dtype (saturating
        add/subtract via cv2).
    """
    # The original unpacked height/width here but never used them.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    top_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_TOPHAT, kernel)
    black_hat = cv2.morphologyEx(img_grayscale, cv2.MORPH_BLACKHAT, kernel)
    return cv2.subtract(cv2.add(img_grayscale, top_hat), black_hat)
def detect_white_annotation(img: np.ndarray) -> np.ndarray:
    """Return a single-channel mask of near-white text/annotation pixels."""
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    contrasted = maximize_contrast(grayscale)
    contrasted_bgr = cv2.cvtColor(contrasted, cv2.COLOR_GRAY2BGR)
    # Threshold at 254 keeps only essentially-pure-white pixels.
    white_mask = mask_filter(contrasted_bgr, 254)
    grown = cv2.dilate(white_mask, np.ones((3, 3), np.uint8), iterations=1)
    return cv2.cvtColor(grown, cv2.COLOR_BGR2GRAY)
def detect_cyan(img: np.ndarray) -> np.ndarray:
    """Return a binary mask of cyan pixels (caliper/text overlays).

    NOTE(review): this detector converts with COLOR_BGR2HSV while the
    purple/orange/green detectors use COLOR_RGB2HSV — confirm which
    channel order callers actually pass.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_bound = np.uint8([85, 150, 20])
    upper_bound = np.uint8([95, 255, 255])
    return np.array(cv2.inRange(hsv, lower_bound, upper_bound))
def detect_purple_text(img: np.ndarray) -> np.ndarray:
    """Return a binary mask of purple text pixels (hue 110-130 in HSV)."""
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    lower_bound = np.uint8([110, 100, 50])
    upper_bound = np.uint8([130, 255, 255])
    return np.array(cv2.inRange(hsv, lower_bound, upper_bound))
def detect_orange_text(img: np.ndarray) -> np.ndarray:
    """Return a binary mask of orange text pixels (hue 12-27 in HSV)."""
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    lower_bound = np.uint8([12, 150, 100])
    upper_bound = np.uint8([27, 255, 255])
    return np.array(cv2.inRange(hsv, lower_bound, upper_bound))
def detect_green_text(img: np.ndarray) -> np.ndarray:
    """Return a binary mask of green text pixels (hue 50-70 in HSV)."""
    hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    lower_bound = np.uint8([50, 100, 50])
    upper_bound = np.uint8([70, 255, 255])
    return np.array(cv2.inRange(hsv, lower_bound, upper_bound))
def detect_annotation(img: np.ndarray) -> np.ndarray:
    """Detect all text annotations (white, cyan, orange, purple, green).

    Unions the per-colour masks, sharpens the union with morphological
    contrast, then ORs an Otsu-thresholded Gaussian blur back in so thin
    strokes survive the thresholding.
    """
    detectors = (
        detect_white_annotation,
        detect_cyan,
        detect_orange_text,
        detect_purple_text,
        detect_green_text,
    )
    votes = np.zeros(img.shape[:2], dtype=np.float32)
    for detect in detectors:
        votes += (detect(img) >= 127).astype(np.float32)
    union = (votes > 0).astype(np.uint8) * 255
    union = maximize_contrast(union)
    blurred = cv2.GaussianBlur(union, (5, 5), 0)
    _, otsu = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.bitwise_or(otsu, union)
# ============================================================================
# DICOM UTILITIES (from utils_adam.py)
# ============================================================================
def remove_text_box(im: np.ndarray, box_background_pixel: np.ndarray, min_rect_area: int = 2000) -> np.ndarray:
    """Black out the largest text box matching *box_background_pixel*.

    Finds pixels whose colour equals the box background exactly, takes the
    largest connected contour, and zeroes its bounding rectangle if it is
    at least *min_rect_area* pixels. Mutates and returns *im*.
    """
    background_hits = np.all(im == box_background_pixel, axis=-1)
    mask = background_hits.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return im
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    if w * h >= min_rect_area:
        im[y:y + h, x:x + w] = 0
    return im
def pad_to_square(im: np.ndarray) -> np.ndarray:
    """Center *im* on a black square whose side equals its longer edge."""
    if ALBUMENTATIONS_AVAILABLE:
        side = max(im.shape[:2])
        transform = A.PadIfNeeded(min_height=side, min_width=side,
                                  border_mode=0, value=(0, 0, 0))
        return transform(image=im)["image"]
    # Manual fallback when albumentations is not installed.
    height, width = im.shape[:2]
    side = max(height, width)
    # Works for both 2-D (grey) and 3-D (colour) arrays.
    canvas = np.zeros((side, side) + im.shape[2:], dtype=im.dtype)
    top = (side - height) // 2
    left = (side - width) // 2
    canvas[top:top + height, left:left + width] = im
    return canvas
def get_fan_region(im: np.ndarray, threshold: int = 1) -> np.ndarray:
    """Crop to the largest bright contour (the ultrasound fan/cone).

    Everything outside the contour within the crop is zeroed. Returns the
    input unchanged when no contour is found.
    """
    grey = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(grey, threshold, 255, 0)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return im
    fan_contour = max(contours, key=cv2.contourArea)
    # Filled contour mask, same shape as the image.
    mask = np.zeros_like(im)
    cv2.drawContours(mask, [fan_contour], -1, (255, 255, 255), thickness=cv2.FILLED)
    # Crop both image and mask to the contour's bounding box, then combine.
    x, y, w, h = cv2.boundingRect(fan_contour)
    return cv2.bitwise_and(im[y:y + h, x:x + w], mask[y:y + h, x:x + w])
def get_us_region_from_dcm(us, sv_mc_y: int = 1) -> np.ndarray:
    """Extract the ultrasound region from a DICOM dataset using metadata.

    Reads SequenceOfUltrasoundRegions to find the scan region(s). When two
    tissue regions (RegionDataType == 1) are present — a picture-in-picture
    "small view" — the small view is blanked to zero before cropping to the
    union of both regions.

    Args:
        us: pydicom dataset exposing ``pixel_array`` and region tags.
        sv_mc_y: Extra rows blanked above the small view (margin crop).

    Returns:
        Cropped (H, W, C) pixel array; converted to RGB when the
        PhotometricInterpretation is a YBR_FULL variant.
    """
    # Initialize default coordinates: small-view coords unset, crop = full frame
    x0_f, x1_f, y0_f, y1_f = None, None, None, None
    x0, x1, y0, y1 = 0, us.pixel_array.shape[1], 0, us.pixel_array.shape[0]
    # Check for ultrasound regions metadata
    if hasattr(us, 'SequenceOfUltrasoundRegions') and len(us.SequenceOfUltrasoundRegions) > 0:
        regions = us.SequenceOfUltrasoundRegions
        if len(regions) == 2 and int(regions[0].RegionDataType) == 1 and int(regions[1].RegionDataType) == 1:
            # Image with small view (picture-in-picture)
            x0_f = np.min([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            # NOTE(review): x1_f takes the max of the two MinX0 values, not a
            # MaxX1 — verify this blanks the small view's full width.
            x1_f = np.max([regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0])
            # NOTE(review): y0_f uses max of MinY0 (lower of the two tops) —
            # presumably selects the small view's top edge; confirm.
            y0_f = np.max([regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0])
            y1_f = np.max([regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1])
            # Overall crop spans the union of both regions.
            x0 = min(regions[0].RegionLocationMinX0, regions[1].RegionLocationMinX0)
            x1 = max(regions[0].RegionLocationMaxX1, regions[1].RegionLocationMaxX1)
            y0 = min(regions[0].RegionLocationMinY0, regions[1].RegionLocationMinY0)
            y1 = max(regions[0].RegionLocationMaxY1, regions[1].RegionLocationMaxY1)
        elif len(regions) >= 1 and int(regions[0].RegionDataType) == 1:
            # Single tissue region: crop straight to it.
            x0 = regions[0].RegionLocationMinX0
            x1 = regions[0].RegionLocationMaxX1
            y0 = regions[0].RegionLocationMinY0
            y1 = regions[0].RegionLocationMaxY1
    # Work on a copy so the dataset's cached pixel_array is not mutated.
    ds = copy.deepcopy(us.pixel_array)
    # Handle color space conversion (YBR_FULL / YBR_FULL_422 -> RGB)
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)
    # Remove small view if present (blank it to black before cropping)
    if x0_f is not None:
        ds[y0_f - sv_mc_y:y1_f, x0_f:x1_f, :] = 0
    # Crop to ultrasound region
    ds = ds[y0:y1, x0:x1, :]
    return ds
# ============================================================================
# MAIN PREPROCESSING FUNCTIONS
# ============================================================================
def preprocess_dicom(file_bytes: bytes) -> Tuple[Image.Image, Dict]:
    """
    Full DICOM preprocessing pipeline.

    Steps:
    1. Parse DICOM file
    2. Extract ultrasound region from metadata
    3. Remove yellow text boxes
    4. Extract fan/cone region
    5. Detect text annotations
    6. Inpaint to remove text
    7. Denoise using non-local means (when skimage is available)
    8. Pad to square (model resizes to its own input size)

    Args:
        file_bytes: Raw bytes of the DICOM file.

    Returns:
        Tuple of (PIL Image, info dict with pipeline steps and metadata).

    Raises:
        RuntimeError: If pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed. Install with: pip install pydicom")
    # Parse DICOM
    us = dcmread(BytesIO(file_bytes))
    # Extract ultrasound region using the DICOM region metadata
    ds = get_us_region_from_dcm(us, sv_mc_y=SMALL_VIEW_MARGIN_CROP_Y)
    # Remove burned-in text box (matched by its flat background colour)
    img = remove_text_box(ds.copy(), box_background_pixel=YELLOW_BOX_BACKGROUND_PIXEL,
                          min_rect_area=MIN_YELLOW_BOX_RECT_AREA)
    # Isolate the ultrasound fan/cone
    fan = get_fan_region(img, threshold=INTENSITY_THRESHOLD)
    # Detect annotations and grow the mask so inpainting covers glyph edges
    mask_inpaint = detect_annotation(fan)
    mask_inpaint = cv2.dilate(mask_inpaint, MASK_INPAINTING_DILATE_KERNEL)
    # Inpaint to remove text. cv2.inpaint does not modify its source, so the
    # original `fan.copy()` (misleadingly named image_grey) was unnecessary.
    dst = cv2.inpaint(fan, mask_inpaint, INPAINT_RADIUS, cv2.INPAINT_TELEA)
    # Normalize to [0, 1]; compute the peak once instead of twice.
    peak = float(np.max(dst))
    if peak > 0:
        dst = dst / peak
    # Denoise with non-local means when skimage is available
    if SKIMAGE_AVAILABLE:
        sigma = estimate_sigma(dst, channel_axis=-1, average_sigmas=True)
        median = denoise_nl_means(dst, h=0.8 * sigma, fast_mode=True, **DENOISE_NL_MEANS_PATCH_KW)
        median = np.clip(median * 255, 0, 255).astype(np.uint8)
    else:
        median = np.clip(dst * 255, 0, 255).astype(np.uint8)
    # Pad to square (model will resize to 224x224)
    img = pad_to_square(median)
    padded_size = max(img.shape[:2])
    # Extract metadata, falling back to the fan's geometry when tags are missing
    try:
        rows = getattr(us, 'Rows', fan.shape[0])
        columns = getattr(us, 'Columns', fan.shape[1])
        if hasattr(us, 'PixelSpacing') and us.PixelSpacing is not None:
            orig_pixel_spacing = [float(sp) for sp in us.PixelSpacing]
        else:
            orig_pixel_spacing = [1.0, 1.0]
    except Exception:
        rows = fan.shape[0]
        columns = fan.shape[1]
        orig_pixel_spacing = [1.0, 1.0]
    metadata = {
        'original_size': (rows, columns),
        'original_pixel_spacing': orig_pixel_spacing,
        'fan_size': (fan.shape[0], fan.shape[1]),
        'padded_size': padded_size,
        'pixel_spacing': orig_pixel_spacing[0],  # Original spacing, model.py adjusts for resize
        'processed_size': (padded_size, padded_size),
    }
    # Convert to PIL (promote grayscale to 3-channel; swap BGR for colour)
    if len(img.shape) == 2:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img_rgb)
    steps_applied = [
        "dicom_parsing",
        "us_region_extraction",
        "text_box_removal",
        "fan_extraction",
        "annotation_detection",
        "inpainting",
        "denoising" if SKIMAGE_AVAILABLE else "normalization",
        "square_padding",
    ]
    return img_pil, {
        "type": "dicom",
        "pipeline": "full",
        "steps_applied": steps_applied,
        "metadata": metadata
    }
def preprocess_image(image: Image.Image) -> Tuple[Image.Image, Dict]:
    """
    Basic image preprocessing pipeline: RGB conversion + centred square padding.
    (The model performs the final resize to 224 itself.)

    Returns:
        Tuple of (PIL Image, preprocessing info dict)
    """
    # Normalise any exotic mode to RGB; keep plain grayscale as-is for now.
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    width, height = image.size
    side = max(width, height)
    # Black square canvas, with the original image pasted in the centre.
    fill = (0, 0, 0) if image.mode == "RGB" else 0
    square = Image.new(image.mode, (side, side), fill)
    square.paste(image, ((side - width) // 2, (side - height) // 2))
    # Grayscale is promoted to RGB only after padding.
    if square.mode == 'L':
        square = square.convert('RGB')
    info = {
        "type": "image",
        "pipeline": "basic",
        "steps_applied": [
            "rgb_conversion",
            "square_padding",
        ],
        "metadata": {
            "original_size": (height, width),
            "processed_size": (side, side),
        },
    }
    return square, info
def is_dicom_file(file_bytes: bytes, filename: str) -> bool:
    """Heuristically decide whether the given file is a DICOM file.

    Checks the filename extension first, then the Part-10 "DICM" magic
    marker stored at byte offset 128.

    Args:
        file_bytes: Raw file contents.
        filename: Original filename (used for the extension check).

    Returns:
        True if the file looks like DICOM, False otherwise.
    """
    # Check by extension (case-insensitive); endswith accepts a tuple.
    if filename.lower().endswith(('.dcm', '.dicom')):
        return True
    # Check DICOM magic number: "DICM" occupies bytes 128..131, so 132
    # bytes suffice. The original `> 132` missed exactly-132-byte files.
    if len(file_bytes) >= 132 and file_bytes[128:132] == b'DICM':
        return True
    return False
def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its PNG serialization.

    Args:
        image: Any PIL image that PNG can represent.

    Returns:
        ASCII base64 string of the PNG bytes.
    """
    # `import base64` was previously buried mid-function; it now lives in
    # the module's import block.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def get_dicom_preview(file_bytes: bytes) -> str:
    """Extract the raw image from a DICOM for preview (no preprocessing).

    Args:
        file_bytes: Raw bytes of the DICOM file.

    Returns:
        Base64-encoded PNG of the (first frame of the) pixel data.

    Raises:
        RuntimeError: If pydicom is not installed.
    """
    if not DICOM_AVAILABLE:
        raise RuntimeError("pydicom not installed")
    us = dcmread(BytesIO(file_bytes))
    ds = us.pixel_array
    # Handle color space (YBR_FULL variants -> RGB)
    if hasattr(us, 'PhotometricInterpretation'):
        if 'ybr_full' in us.PhotometricInterpretation.lower():
            ds = convert_color_space(ds, "YBR_FULL", "RGB", per_frame=True)
    # Handle video (take first frame)
    if len(ds.shape) == 4:
        ds = ds[0]
    # Normalize high-bit-depth data to 0-255.
    if ds.max() > 255:
        lo = int(ds.min())
        hi = int(ds.max())
        if hi > lo:
            ds = ((ds - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            # Constant frame: the original divided by zero here.
            ds = np.zeros(ds.shape, dtype=np.uint8)
    # Convert to RGB if grayscale.
    if len(ds.shape) == 2:
        ds = cv2.cvtColor(ds, cv2.COLOR_GRAY2RGB)
    elif ds.shape[2] == 3:
        # NOTE(review): assumes 3-channel pixel data is BGR-ordered; if the
        # dataset is already RGB this swap inverts the channels — confirm.
        ds = cv2.cvtColor(ds, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(ds)
    return image_to_base64(img_pil)
def preprocess_file(file_bytes: bytes, filename: str) -> Tuple[Image.Image, Dict]:
    """
    Automatically detect file type and apply appropriate preprocessing.

    Returns:
        Tuple of (PIL Image, preprocessing info dict with base64 image)
    """
    if is_dicom_file(file_bytes, filename):
        processed, info = preprocess_dicom(file_bytes)
    else:
        # Regular image
        processed, info = preprocess_image(Image.open(BytesIO(file_bytes)))
    # Attach a base64 PNG so the frontend can display the preprocessed frame.
    info["processed_image_base64"] = image_to_base64(processed)
    return processed, info