NeuroSAM3 / dicom_utils.py
mmrech's picture
Refactor codebase: Add modular structure, logging, validation, and comprehensive improvements
69066c5
"""
DICOM processing utilities for NeuroSAM 3 application.
Handles DICOM file reading, windowing, and image preprocessing.
"""
from typing import Tuple, Optional
import numpy as np
import pydicom
from pydicom.errors import InvalidDicomError
from PIL import Image
from logger_config import logger
from config import CT_WINDOW_PRESETS, OUTPUT_DPI
def get_window_params(window_type: str, modality: str) -> Tuple[float, float]:
    """
    Look up the display window (level, width) for an image.

    Args:
        window_type: Name of a CT window preset (e.g. "Brain (Grey Matter)").
            Unknown names fall back to the "Default" preset.
        modality: Imaging modality ("CT" or "MRI").

    Returns:
        Tuple of (level, width). For any modality other than "CT" the pair
        (0.0, 0.0) is returned, because only CT uses window presets.
    """
    if modality != "CT":
        # Non-CT images are not windowed; callers in this module use
        # percentile normalization for them instead.
        return 0.0, 0.0
    fallback = CT_WINDOW_PRESETS["Default"]
    chosen = CT_WINDOW_PRESETS.get(window_type, fallback)
    return chosen["level"], chosen["width"]
def apply_ct_windowing(img_hu: np.ndarray, level: float, width: float) -> np.ndarray:
    """
    Map Hounsfield units into [0, 1] through a level/width display window.

    Values at or below ``level - width / 2`` map to 0. Values at or above
    ``level + width / 2`` map to 1. Everything in between is scaled
    linearly. If the window is degenerate (its span is not positive), the
    image's own minimum and maximum define the window instead.

    Args:
        img_hu: Image in Hounsfield units
        level: Window level (centre)
        width: Window width

    Returns:
        Windowed image array clipped to [0, 1]

    Raises:
        ValueError: If neither the window nor the image has a positive range.
    """
    half_width = width / 2
    low = level - half_width
    high = level + half_width
    span = high - low

    if span <= 0:
        # Degenerate window: stretch over the data's full dynamic range.
        low = np.min(img_hu)
        high = np.max(img_hu)
        span = high - low
        if span <= 0:
            raise ValueError("Invalid image range for windowing")

    scaled = (img_hu - low) / span
    return np.clip(scaled, 0, 1)
def apply_mri_normalization(img_array: np.ndarray) -> np.ndarray:
    """
    Rescale an MRI image to [0, 1] using its 1st and 99th percentiles.

    Using percentiles rather than the raw extremes keeps a few very bright
    or dark outlier pixels from compressing the useful contrast. If the
    percentile range collapses to zero, the full min/max range is used
    instead.

    Args:
        img_array: Image array

    Returns:
        Normalized image array clipped to [0, 1]

    Raises:
        ValueError: If the image has no intensity range at all.
    """
    lo, hi = (np.percentile(img_array, q) for q in (1, 99))
    spread = hi - lo

    if spread <= 0:
        # Percentiles coincide (e.g. mostly-constant image): use full range.
        lo = np.min(img_array)
        hi = np.max(img_array)
        spread = hi - lo
        if spread <= 0:
            raise ValueError("Invalid image range for normalization")

    return np.clip((img_array - lo) / spread, 0, 1)
def read_dicom_file(file_path: str) -> Tuple[np.ndarray, Optional[pydicom.Dataset]]:
    """
    Read a DICOM file and return its rescaled pixel data.

    The stored pixel values are converted to float32 and mapped through
    ``value * RescaleSlope + RescaleIntercept``. This yields Hounsfield units
    for CT. Missing or empty rescale tags default to slope 1 and intercept 0.

    Args:
        file_path: Path to DICOM file

    Returns:
        Tuple of (rescaled pixel array, dataset)

    Raises:
        InvalidDicomError: If file is not a valid DICOM file
        ValueError: If DICOM file doesn't contain pixel data
    """
    try:
        ds = pydicom.dcmread(file_path)
        try:
            pixels = ds.pixel_array
        except AttributeError as err:
            # Raised when the dataset has no decodable pixel data element.
            raise ValueError("DICOM file does not contain pixel data") from err
        raw = pixels.astype(np.float32)

        # Apply rescale slope and intercept. The tags may be absent, present
        # but empty (None), or decoded as a Decimal-based DS value. Normalize
        # all of these to plain floats so the NumPy arithmetic is well-defined.
        slope = getattr(ds, 'RescaleSlope', None)
        intercept = getattr(ds, 'RescaleIntercept', None)
        slope = 1.0 if slope in (None, "") else float(slope)
        intercept = 0.0 if intercept in (None, "") else float(intercept)
        img_hu = raw * slope + intercept

        logger.debug(f"DICOM file read: {file_path}, shape={img_hu.shape}")
        return img_hu, ds
    except InvalidDicomError as e:
        logger.error(f"Invalid DICOM file format: {file_path}, error: {e}")
        raise
    except Exception as e:
        logger.error(f"Error reading DICOM file: {file_path}, error: {e}")
        raise
def process_dicom_to_pil(
    file_path: str,
    modality: str,
    window_type: str
) -> Image.Image:
    """
    Process DICOM file and convert to PIL Image.

    CT data is windowed with the preset named by ``window_type``. Any other
    modality gets percentile-based normalization. Datasets whose
    PhotometricInterpretation is MONOCHROME1 (where the minimum value is
    meant to be displayed as white) are inverted after windowing. The output
    therefore always has "bright = high value" polarity.

    Args:
        file_path: Path to DICOM file
        modality: Imaging modality ("CT" or "MRI")
        window_type: Window type for CT images

    Returns:
        PIL Image ready for processing. 2-D (single-plane) data is expanded
        to RGB.

    Raises:
        InvalidDicomError: If file is not a valid DICOM file
        ValueError: If processing fails
    """
    img_hu, ds = read_dicom_file(file_path)

    # Apply windowing/normalization based on modality
    if modality == "CT":
        level, width = get_window_params(window_type, modality)
        img_windowed = apply_ct_windowing(img_hu, level, width)
    else:  # MRI
        img_windowed = apply_mri_normalization(img_hu)

    # MONOCHROME1 stores inverted grayscale. Flip it after the VOI
    # windowing step so it renders like MONOCHROME2 instead of as a negative.
    photometric = str(getattr(ds, "PhotometricInterpretation", "")).strip().upper()
    if photometric == "MONOCHROME1":
        img_windowed = 1.0 - img_windowed

    # Convert [0, 1] floats to uint8 (truncating, as before)
    img_uint8 = (img_windowed * 255).astype(np.uint8)

    if img_uint8.ndim == 2:
        pil_image = Image.fromarray(img_uint8).convert('RGB')
    else:
        # NOTE(review): 3-D data is passed to Pillow as-is. This works for
        # colour (rows, cols, 3/4) but presumably fails for multi-frame
        # grayscale (frames, rows, cols). Confirm whether that input is
        # expected.
        pil_image = Image.fromarray(img_uint8)
    logger.debug(f"DICOM processed to PIL Image: shape={img_uint8.shape}")
    return pil_image
def process_standard_image_to_pil(
    file_path: str,
    modality: str,
    window_type: str
) -> Image.Image:
    """
    Process standard image file (PNG, JPG, etc.) and convert to PIL Image.

    The image is first converted to 8-bit RGB. For CT, each channel's
    0-255 intensities are passed through the CT window selected by
    ``window_type`` as if they were Hounsfield units. Otherwise the whole
    array is percentile-normalized.

    Args:
        file_path: Path to image file
        modality: Imaging modality ("CT" or "MRI")
        window_type: Window type for CT images (unused for other modalities)

    Returns:
        RGB PIL Image ready for processing

    Raises:
        ValueError: If processing fails (e.g. the image has no intensity range)

    Exceptions raised by ``Image.open`` (e.g. for a missing or unreadable
    file) propagate unchanged.
    """
    # Use a context manager so the file handle is always released. convert()
    # returns a new, fully loaded image (a copy when the mode is already
    # RGB), so it remains valid after the source file is closed.
    with Image.open(file_path) as source:
        rgb_image = source.convert('RGB')

    # An RGB image always yields an (H, W, 3) array, so no grayscale
    # channel expansion is needed.
    img_float = np.array(rgb_image).astype(np.float32)

    if modality == "CT":
        level, width = get_window_params(window_type, modality)
        # Window each channel independently. The full-range fallback in
        # apply_ct_windowing is therefore evaluated per channel.
        img_normalized = np.zeros_like(img_float)
        for channel in range(img_float.shape[-1]):
            img_normalized[..., channel] = apply_ct_windowing(
                img_float[..., channel], level, width
            )
    else:  # MRI - use percentile normalization
        img_normalized = apply_mri_normalization(img_float)

    # Convert [0, 1] floats back to uint8
    img_uint8 = (img_normalized * 255).astype(np.uint8)
    pil_image = Image.fromarray(img_uint8)
    logger.debug(f"Standard image processed to PIL Image: shape={img_uint8.shape}")
    return pil_image
def is_dicom_file(file_path: str, *, sniff_content: bool = False) -> bool:
    """
    Check whether a file should be treated as DICOM.

    By default only the file extension is inspected. ``.dcm`` and ``.dicom``
    are accepted, case-insensitively. Many DICOM files carry no extension
    at all. For those, pass ``sniff_content=True`` to also look for the
    ``DICM`` signature that DICOM Part 10 files place after a 128-byte
    preamble.

    Args:
        file_path: Path to file
        sniff_content: If True and the extension is not recognized, read the
            first 132 bytes of the file and check bytes 128-131 for ``DICM``.
            Missing or unreadable files are reported as non-DICOM.

    Returns:
        True if the file is treated as DICOM, False otherwise
    """
    import os

    ext = os.path.splitext(file_path)[1].lower()
    if ext in (".dcm", ".dicom"):
        return True
    if not sniff_content:
        return False

    try:
        with open(file_path, "rb") as fh:
            header = fh.read(132)
    except OSError:
        return False
    return header[128:132] == b"DICM"