""" Utility functions for NeuroSAM 3 application. Helper functions for image processing, visualization, and common operations. """ from typing import Optional, Tuple, List, Dict, Any import os import re import tempfile import numpy as np import pydicom from PIL import Image import matplotlib.pyplot as plt from logger_config import logger def extract_subject_id(file_path: str) -> Tuple[str, str, str]: """ Extract subject/patient ID from file path. Common patterns: - Folder name: /subject_001/image.png -> subject_001 - Filename prefix: subject_001_slice_01.png -> subject_001 - Patient ID in filename: patient_123_slice_5.dcm -> patient_123 - Study UID in DICOM: extract from DICOM metadata Args: file_path: Path to file Returns: Tuple of (subject_id, confidence_level, source) confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename pattern), 'low' (fallback) source: 'dicom_patientid', 'dicom_study', 'folder', 'filename', 'fallback' """ file_path = str(file_path) filename = os.path.basename(file_path) dir_path = os.path.dirname(file_path) # HIGHEST CONFIDENCE: DICOM metadata (most reliable) if file_path.lower().endswith('.dcm'): try: ds = pydicom.dcmread(file_path, stop_before_pixels=True) patient_id = getattr(ds, 'PatientID', None) if patient_id and patient_id.strip(): return f"patient_{patient_id}", 'high', 'dicom_patientid' study_uid = getattr(ds, 'StudyInstanceUID', None) if study_uid: # Use full study UID as identifier (unique per study) return f"study_{study_uid}", 'high', 'dicom_study' except Exception as e: logger.debug(f"Could not read DICOM metadata: {e}") # MEDIUM CONFIDENCE: Folder name (common in medical datasets) folder_name = os.path.basename(dir_path.rstrip('/')) if folder_name and folder_name not in ['', '.', '..']: # Check if folder name looks like a subject ID if re.match(r'(subject|patient|sub|pat|case|id)[_-]?\d+', folder_name, re.I): return folder_name, 'medium', 'folder' # MEDIUM CONFIDENCE: Filename pattern patterns = [ (r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', 'medium'), # subject_001, patient_123 (r'([A-Z]{2,}\d+)', 'medium'), # BR001, MR123, etc. ] for pattern, confidence in patterns: match = re.search(pattern, filename, re.I) if match: if len(match.groups()) > 1: return f"{match.group(1)}_{match.group(2)}", confidence, 'filename' else: return match.group(1), confidence, 'filename' # LOW CONFIDENCE: Numeric pattern (could be slice number, not patient ID) numeric_match = re.search(r'(\d{3,})', filename) if numeric_match: return numeric_match.group(1), 'low', 'filename_numeric' # LOWEST CONFIDENCE: Fallback to filename base_name = os.path.splitext(filename)[0] if len(base_name) > 0: return base_name, 'low', 'fallback' return "unknown", 'low', 'unknown' def group_images_by_subject(image_files: List[str]) -> Dict[str, Dict[str, Any]]: """ Group image files by subject/patient ID. Args: image_files: List of file paths Returns: Dictionary: {subject_id: {'files': [...], 'confidence': 'high/medium/low', 'sources': set(...)}} """ if not image_files: return {} if isinstance(image_files, str): image_files = [image_files] # Filter out None files image_files = [f for f in image_files if f is not None] # Group by subject ID and track confidence subject_groups = {} for file_path in image_files: subject_id, confidence, source = extract_subject_id(file_path) if subject_id not in subject_groups: subject_groups[subject_id] = { 'files': [], 'confidence': confidence, 'sources': set([source]) } subject_groups[subject_id]['files'].append(file_path) subject_groups[subject_id]['sources'].add(source) # Upgrade confidence if we find high-confidence source if confidence == 'high' or (confidence == 'medium' and subject_groups[subject_id]['confidence'] == 'low'): subject_groups[subject_id]['confidence'] = confidence # Sort files within each group (by filename) for subject_id in subject_groups: subject_groups[subject_id]['files'].sort() subject_groups[subject_id]['sources'] = list(subject_groups[subject_id]['sources']) return subject_groups def combine_masks(masks) -> Optional[np.ndarray]: """ Combine multiple mask arrays into a single mask. Args: masks: List of mask arrays, or numpy array, or None Returns: Combined mask array or None if no valid masks """ if masks is None: return None # Handle numpy array input (convert to list) if isinstance(masks, np.ndarray): if masks.ndim == 0: # Scalar return None elif masks.ndim == 1: # 1D array - might be empty if len(masks) == 0: return None masks = [masks] # Convert to list else: # Multi-dimensional array - treat as single mask return masks # Handle list/tuple input if not isinstance(masks, (list, tuple)): # Try to convert to list try: masks = list(masks) except Exception: return None if len(masks) == 0: return None mask_arrays = [] for mask in masks: if isinstance(mask, np.ndarray): mask_arrays.append(mask) else: # Try to convert to numpy try: mask_np = np.array(mask) if mask_np.size > 0: # Only add non-empty arrays mask_arrays.append(mask_np) except Exception as e: logger.debug(f"Could not convert mask to numpy: {e}") continue if len(mask_arrays) == 0: return None # Combine all masks using logical OR try: # Ensure all masks have the same shape and are 2D # First, convert any 3D masks to 2D mask_arrays_2d = [] for mask in mask_arrays: if mask.ndim == 3: # If 3D, take first channel or convert to grayscale if mask.shape[0] == 3 or mask.shape[2] == 3: if mask.shape[0] == 3: mask = np.mean(mask, axis=0) > 0.5 else: mask = np.mean(mask, axis=2) > 0.5 else: mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0] elif mask.ndim > 3: mask = mask.squeeze() if mask.ndim != 2: mask = mask.reshape(mask.shape[-2], mask.shape[-1]) # Ensure boolean if mask.dtype != bool: mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2) mask_arrays_2d.append(mask) # Resize masks to same shape if needed if len(mask_arrays_2d) > 1: target_shape = mask_arrays_2d[0].shape for i in range(1, len(mask_arrays_2d)): if mask_arrays_2d[i].shape != target_shape: from scipy.ndimage import zoom zoom_factors = ( target_shape[0] / mask_arrays_2d[i].shape[0], target_shape[1] / mask_arrays_2d[i].shape[1] ) mask_arrays_2d[i] = zoom(mask_arrays_2d[i].astype(float), zoom_factors, order=0) > 0.5 combined_mask = np.any(mask_arrays_2d, axis=0) return combined_mask.astype(bool) except Exception as e: logger.error(f"Error combining masks: {e}", exc_info=True) return None def create_output_image( pil_image: Image.Image, mask: Optional[np.ndarray], prompt_text: str, colormap: str = 'spring', transparency: float = 0.5, title: Optional[str] = None ) -> str: """ Create output visualization image with mask overlay. Args: pil_image: Base PIL image mask: Optional mask array to overlay (2D or 3D) prompt_text: Prompt text for title colormap: Matplotlib colormap name transparency: Mask transparency (0.0-1.0) title: Optional custom title Returns: Path to saved output image """ plt.figure(figsize=(10, 10)) plt.imshow(pil_image) if mask is not None: # Ensure mask is 2D for matplotlib imshow if isinstance(mask, np.ndarray): if mask.ndim == 3: # If 3D, take first channel or convert to grayscale if mask.shape[0] == 3 or mask.shape[2] == 3: # RGB-like format: convert to grayscale if mask.shape[0] == 3: # Shape is (3, H, W) - take mean across channels mask = np.mean(mask, axis=0) else: # Shape is (H, W, 3) - convert to grayscale mask = np.mean(mask, axis=2) else: # Take first channel mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0] elif mask.ndim > 3: # Flatten extra dimensions mask = mask.squeeze() if mask.ndim != 2: logger.warning(f"Mask has {mask.ndim} dimensions, expected 2D. Flattening...") mask = mask.reshape(mask.shape[-2], mask.shape[-1]) # Ensure mask is boolean or binary (0-1 range) if mask.dtype != bool: # Convert to boolean if not already mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2) plt.imshow(mask, alpha=transparency, cmap=colormap) plt.axis('off') display_title = title or f"Segmentation: {prompt_text}" plt.title(display_title, fontsize=12, pad=10) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() from config import OUTPUT_DPI plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=OUTPUT_DPI) plt.close() return output_path def create_demo_dicom_file(output_path: str = "demo_brain_mri.dcm") -> bool: """ Create a demo DICOM file for testing. Args: output_path: Path where to save the demo file Returns: True if successful, False otherwise """ try: from pydicom.data import get_testdata_file test_file = get_testdata_file("MR_small.dcm") if test_file and os.path.exists(test_file): import shutil shutil.copy(test_file, output_path) logger.info(f"Demo file ready: {output_path}") return True except Exception as e: logger.debug(f"Could not copy test DICOM file: {e}") try: # Create synthetic DICOM file from pydicom.dataset import FileDataset, FileMetaDataset from pydicom.uid import generate_uid synthetic_image = np.random.randint(0, 255, (256, 256), dtype=np.uint16) center_x, center_y = 128, 128 y, x = np.ogrid[:256, :256] mask = (x - center_x)**2 + (y - center_y)**2 <= 100**2 synthetic_image[mask] = np.clip(synthetic_image[mask] + 50, 0, 255) file_meta = FileMetaDataset() file_meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.4' file_meta.MediaStorageSOPInstanceUID = generate_uid() file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1' ds = FileDataset(output_path, {}, file_meta=file_meta, preamble=b"\x00" * 128) ds.PatientName = "Demo^Patient" ds.PatientID = "DEMO001" ds.Modality = "MR" ds.Rows = 256 ds.Columns = 256 ds.BitsAllocated = 16 ds.BitsStored = 16 ds.HighBit = 15 ds.SamplesPerPixel = 1 ds.PixelRepresentation = 0 ds.PhotometricInterpretation = "MONOCHROME2" ds.PixelSpacing = [1.0, 1.0] ds.RescaleIntercept = "0" ds.RescaleSlope = "1" ds.PixelData = synthetic_image.tobytes() ds.save_as(output_path, write_like_original=False) logger.info(f"Synthetic demo file created: {output_path}") return True except Exception as e: logger.warning(f"Could not create demo file: {e}") return False