""" Core segmentation functions for NeuroSAM 3 application. Handles segmentation operations, ROI statistics, and mask processing. """ from typing import Optional, Tuple, Dict, Any, List import os import tempfile import numpy as np import pydicom from PIL import Image import matplotlib.pyplot as plt from scipy import ndimage from logger_config import logger from config import OUTPUT_DPI from utils import combine_masks def compare_with_ground_truth( pred_mask: np.ndarray, gt_mask_path: str ) -> Tuple[Optional[str], float, float]: """ Compare SAM 3 prediction with ground truth mask and return comparison metrics. Args: pred_mask: Predicted mask array gt_mask_path: Path to ground truth mask image Returns: Tuple of (comparison_image_path, dice_score, iou_score) """ try: gt_mask = Image.open(gt_mask_path) gt_array = np.array(gt_mask.convert('L')) > 127 # Binarize # Resize prediction mask to match ground truth if needed if pred_mask.shape != gt_array.shape: pred_pil = Image.fromarray((pred_mask * 255).astype(np.uint8)) pred_pil = pred_pil.resize(gt_mask.size, Image.NEAREST) pred_mask = np.array(pred_pil) > 127 # Calculate metrics intersection = np.logical_and(pred_mask, gt_array).sum() union = np.logical_or(pred_mask, gt_array).sum() dice_score = ( (2.0 * intersection) / (pred_mask.sum() + gt_array.sum()) if (pred_mask.sum() + gt_array.sum()) > 0 else 0.0 ) iou_score = intersection / union if union > 0 else 0.0 # Create comparison visualization fig, axes = plt.subplots(1, 3, figsize=(15, 5)) axes[0].imshow(pred_mask, cmap='spring') axes[0].set_title('SAM 3 Prediction') axes[0].axis('off') axes[1].imshow(gt_array, cmap='cool') axes[1].set_title('Ground Truth') axes[1].axis('off') # Overlay comparison comparison = np.zeros((*pred_mask.shape, 3)) comparison[pred_mask & gt_array] = [0, 1, 0] # Green: True Positive comparison[pred_mask & ~gt_array] = [1, 0, 0] # Red: False Positive comparison[~pred_mask & gt_array] = [0, 0, 1] # Blue: False Negative axes[2].imshow(comparison) axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}') axes[2].axis('off') plt.tight_layout() output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', dpi=OUTPUT_DPI) plt.close() return output_path, dice_score, iou_score except Exception as e: logger.error(f"Error comparing with ground truth: {e}", exc_info=True) return None, 0.0, 0.0 def calculate_roi_statistics( image_file: str, mask: np.ndarray, modality: str ) -> Dict[str, Any]: """ Calculate ROI statistics from the segmented region. Args: image_file: Path to original image file mask: Binary mask array modality: Imaging modality ("CT" or "MRI") Returns: Dictionary with statistics including area, mean intensity, std, min, max, centroid """ if mask is None or not isinstance(mask, np.ndarray): return { "error": "No valid mask available", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } try: # Load original image for intensity statistics file_path = str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept else: img = Image.open(file_path) if img.mode == 'RGB': img = img.convert('L') # Convert to grayscale for intensity stats img_array = np.array(img).astype(np.float32) # Resize mask if needed if mask.shape != img_array.shape: zoom_factors = ( img_array.shape[0] / mask.shape[0], img_array.shape[1] / mask.shape[1] ) mask = ndimage.zoom(mask.astype(float), zoom_factors, order=0) > 0.5 # Calculate statistics mask_bool = mask.astype(bool) total_pixels = mask.size roi_pixels = np.sum(mask_bool) if roi_pixels == 0: return { "error": "No pixels in ROI", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } # Intensity statistics roi_intensities = img_array[mask_bool] mean_intensity = float(np.mean(roi_intensities)) std_intensity = float(np.std(roi_intensities)) min_intensity = float(np.min(roi_intensities)) max_intensity = float(np.max(roi_intensities)) # Centroid y_coords, x_coords = np.where(mask_bool) centroid_y = float(np.mean(y_coords)) centroid_x = float(np.mean(x_coords)) # Bounding box if len(y_coords) > 0 and len(x_coords) > 0: bbox_y1 = int(np.min(y_coords)) bbox_x1 = int(np.min(x_coords)) bbox_y2 = int(np.max(y_coords)) bbox_x2 = int(np.max(x_coords)) else: bbox_y1 = bbox_x1 = bbox_y2 = bbox_x2 = 0 area_percentage = (roi_pixels / total_pixels) * 100 return { "area_pixels": int(roi_pixels), "area_percentage": float(area_percentage), "mean_intensity": mean_intensity, "std_intensity": std_intensity, "min_intensity": min_intensity, "max_intensity": max_intensity, "centroid": (centroid_x, centroid_y), "bounding_box": (bbox_x1, bbox_y1, bbox_x2, bbox_y2) } except Exception as e: logger.error(f"Error calculating ROI statistics: {e}", exc_info=True) return { "error": str(e), "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } def format_roi_statistics(stats: Dict[str, Any]) -> str: """ Format ROI statistics dictionary into a readable string. Args: stats: Statistics dictionary from calculate_roi_statistics Returns: Formatted string with statistics """ if "error" in stats: return f"❌ Error: {stats['error']}" return f""" **ROI Statistics:** - **Area**: {stats['area_pixels']} pixels ({stats['area_percentage']:.2f}% of image) - **Intensity**: - Mean: {stats['mean_intensity']:.2f} - Std: {stats['std_intensity']:.2f} - Min: {stats['min_intensity']:.2f} - Max: {stats['max_intensity']:.2f} - **Centroid**: ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f}) - **Bounding Box**: ({stats['bounding_box'][0]}, {stats['bounding_box'][1]}) to ({stats['bounding_box'][2]}, {stats['bounding_box'][3]}) """ def generate_grid_points( image_size: Tuple[int, int], points_per_side: int = 32 ) -> np.ndarray: """ Generate a grid of points across the image for automatic mask generation. Args: image_size: Tuple of (height, width) points_per_side: Number of points per side of the grid Returns: Array of point coordinates (N, 2) where each row is [x, y] """ height, width = image_size # Generate grid coordinates x_coords = np.linspace(0, width - 1, points_per_side) y_coords = np.linspace(0, height - 1, points_per_side) # Create meshgrid x_grid, y_grid = np.meshgrid(x_coords, y_coords) # Flatten and combine points = np.stack([x_grid.flatten(), y_grid.flatten()], axis=1) return points.astype(np.float32) def calculate_dice_score(mask1: np.ndarray, mask2: np.ndarray) -> float: """ Calculate Dice coefficient between two masks. Args: mask1: First binary mask mask2: Second binary mask Returns: Dice coefficient (0.0 to 1.0) """ intersection = np.logical_and(mask1, mask2).sum() union = mask1.sum() + mask2.sum() if union == 0: return 1.0 if intersection == 0 else 0.0 return (2.0 * intersection) / union def calculate_iou_score(mask1: np.ndarray, mask2: np.ndarray) -> float: """ Calculate Intersection over Union (IoU) between two masks. Args: mask1: First binary mask mask2: Second binary mask Returns: IoU score (0.0 to 1.0) """ intersection = np.logical_and(mask1, mask2).sum() union = np.logical_or(mask1, mask2).sum() if union == 0: return 1.0 if intersection == 0 else 0.0 return intersection / union