|
|
""" |
|
|
Core segmentation functions for NeuroSAM 3 application. |
|
|
Handles segmentation operations, ROI statistics, and mask processing. |
|
|
""" |
|
|
|
|
|
from typing import Optional, Tuple, Dict, Any, List |
|
|
import os |
|
|
import tempfile |
|
|
import numpy as np |
|
|
import pydicom |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
from scipy import ndimage |
|
|
from logger_config import logger |
|
|
from config import OUTPUT_DPI |
|
|
from utils import combine_masks |
|
|
|
|
|
|
|
|
def compare_with_ground_truth(
    pred_mask: np.ndarray,
    gt_mask_path: str
) -> Tuple[Optional[str], float, float]:
    """
    Compare SAM 3 prediction with ground truth mask and return comparison metrics.

    The ground-truth image is converted to 8-bit grayscale and thresholded at
    127. If the prediction's shape differs from the ground truth, it is resized
    to the ground-truth size with nearest-neighbour resampling. A three-panel
    PNG (prediction, ground truth, overlay) is written to a new temporary file
    that is not deleted automatically; the caller owns it.

    Overlay colours: green = both masks, red = prediction only (false
    positive), blue = ground truth only (false negative), black = neither.

    Args:
        pred_mask: Predicted mask array; any non-zero value is foreground.
        gt_mask_path: Path to ground truth mask image

    Returns:
        Tuple of (comparison_image_path, dice_score, iou_score). Scores are
        0.0 when both masks are empty. On any error the exception is logged
        and (None, 0.0, 0.0) is returned.
    """
    fig = None
    try:
        # Close the file handle PIL keeps open for lazily-loaded images.
        with Image.open(gt_mask_path) as gt_mask:
            gt_size = gt_mask.size
            gt_array = np.array(gt_mask.convert('L')) > 127

        # Normalise to bool so uint8 (0/255) or float masks work with the
        # bitwise overlay operators below and so sums count pixels, not values.
        pred_bool = np.asarray(pred_mask).astype(bool)
        if pred_bool.shape != gt_array.shape:
            pred_pil = Image.fromarray((pred_bool * 255).astype(np.uint8))
            pred_pil = pred_pil.resize(gt_size, Image.NEAREST)
            pred_bool = np.array(pred_pil) > 127

        intersection = int(np.logical_and(pred_bool, gt_array).sum())
        union = int(np.logical_or(pred_bool, gt_array).sum())
        total = int(pred_bool.sum()) + int(gt_array.sum())
        dice_score = (2.0 * intersection) / total if total > 0 else 0.0
        iou_score = intersection / union if union > 0 else 0.0

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(pred_bool, cmap='spring')
        axes[0].set_title('SAM 3 Prediction')
        axes[0].axis('off')

        axes[1].imshow(gt_array, cmap='cool')
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')

        comparison = np.zeros((*pred_bool.shape, 3))
        comparison[pred_bool & gt_array] = [0, 1, 0]
        comparison[pred_bool & ~gt_array] = [1, 0, 0]
        comparison[~pred_bool & gt_array] = [0, 0, 1]

        axes[2].imshow(comparison)
        axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
        axes[2].axis('off')

        fig.tight_layout()

        # Reserve a unique path, then let matplotlib write to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
            output_path = output_file.name

        fig.savefig(output_path, bbox_inches='tight', dpi=OUTPUT_DPI)

        return output_path, dice_score, iou_score

    except Exception as e:
        logger.error(f"Error comparing with ground truth: {e}", exc_info=True)
        return None, 0.0, 0.0
    finally:
        # Always release the figure, even if plotting or saving failed.
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
|
|
|
def _empty_roi_statistics(error: str) -> Dict[str, Any]:
    """Return a zero-filled ROI statistics dict carrying an error message."""
    return {
        "error": error,
        "area_pixels": 0,
        "area_percentage": 0,
        "mean_intensity": 0,
        "std_intensity": 0,
        "min_intensity": 0,
        "max_intensity": 0,
        "centroid": (0, 0),
        "bounding_box": (0, 0, 0, 0)
    }


def _load_image_as_array(file_path: str) -> np.ndarray:
    """Load a DICOM (by .dcm extension) or PIL-readable image as float32 intensities."""
    if os.path.splitext(file_path)[1].lower() == '.dcm':
        ds = pydicom.dcmread(file_path)
        img_array = ds.pixel_array.astype(np.float32)
        slope = getattr(ds, 'RescaleSlope', 1)
        intercept = getattr(ds, 'RescaleIntercept', 0)
        return img_array * slope + intercept

    with Image.open(file_path) as img:
        # Palette images store indices, and multi-band images (RGB, RGBA,
        # LA, ...) would add a channel axis; reduce both to one grayscale
        # band. Single-band high-depth modes (e.g. I, I;16, F) are kept as-is.
        if img.mode == 'P' or len(img.getbands()) > 1:
            img = img.convert('L')
        return np.array(img).astype(np.float32)


def calculate_roi_statistics(
    image_file: str,
    mask: np.ndarray,
    modality: str
) -> Dict[str, Any]:
    """
    Calculate ROI statistics from the segmented region.

    The image is read as DICOM when the extension is ``.dcm``. In that case
    pixel data is rescaled with ``RescaleSlope``/``RescaleIntercept`` when
    those attributes are present. Otherwise the image is read with PIL, and
    palette or multi-band images are converted to 8-bit grayscale. If the
    mask shape differs from the image shape, the mask is resampled onto the
    image grid with nearest-neighbour interpolation.

    Args:
        image_file: Path to original image file
        mask: Mask array; any non-zero element is part of the ROI.
        modality: Imaging modality ("CT" or "MRI"). Not currently used.

    Returns:
        Dictionary with area_pixels, area_percentage, mean_intensity,
        std_intensity, min_intensity, max_intensity, centroid as (x, y) and
        bounding_box as (x1, y1, x2, y2) with inclusive maxima. If the mask is
        missing, the ROI is empty, or an error occurs, all values are zero and
        an "error" key describes the problem.
    """
    if mask is None or not isinstance(mask, np.ndarray):
        return _empty_roi_statistics("No valid mask available")

    try:
        img_array = _load_image_as_array(str(image_file))

        # Binarise before any resampling so the same-shape and resized paths
        # agree on which mask values count as ROI.
        mask_bool = mask.astype(bool)
        if mask_bool.shape != img_array.shape:
            zoom_factors = (
                img_array.shape[0] / mask_bool.shape[0],
                img_array.shape[1] / mask_bool.shape[1]
            )
            # order=0 (nearest neighbour) keeps values in {0, 1}.
            mask_bool = ndimage.zoom(
                mask_bool.astype(np.float32), zoom_factors, order=0
            ) > 0.5

        roi_pixels = int(np.count_nonzero(mask_bool))
        if roi_pixels == 0:
            return _empty_roi_statistics("No pixels in ROI")

        roi_intensities = img_array[mask_bool]
        y_coords, x_coords = np.nonzero(mask_bool)

        return {
            "area_pixels": roi_pixels,
            "area_percentage": float((roi_pixels / mask_bool.size) * 100),
            "mean_intensity": float(np.mean(roi_intensities)),
            "std_intensity": float(np.std(roi_intensities)),
            "min_intensity": float(np.min(roi_intensities)),
            "max_intensity": float(np.max(roi_intensities)),
            "centroid": (float(np.mean(x_coords)), float(np.mean(y_coords))),
            "bounding_box": (
                int(np.min(x_coords)),
                int(np.min(y_coords)),
                int(np.max(x_coords)),
                int(np.max(y_coords)),
            )
        }

    except Exception as e:
        logger.error(f"Error calculating ROI statistics: {e}", exc_info=True)
        return _empty_roi_statistics(str(e))
|
|
|
|
|
|
|
|
def format_roi_statistics(stats: Dict[str, Any]) -> str:
    """
    Render an ROI statistics dictionary as a short Markdown summary.

    Args:
        stats: Mapping produced by calculate_roi_statistics.

    Returns:
        If ``stats`` has an "error" key, a single-line error message.
        Otherwise a Markdown block listing area, intensity stats, centroid
        and bounding box.
    """
    if "error" in stats:
        return f"❌ Error: {stats['error']}"

    summary_lines = [
        "",
        "**ROI Statistics:**",
        "",
        f"- **Area**: {stats['area_pixels']} pixels ({stats['area_percentage']:.2f}% of image)",
        "- **Intensity**:",
        f"  - Mean: {stats['mean_intensity']:.2f}",
        f"  - Std: {stats['std_intensity']:.2f}",
        f"  - Min: {stats['min_intensity']:.2f}",
        f"  - Max: {stats['max_intensity']:.2f}",
        f"- **Centroid**: ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})",
        f"- **Bounding Box**: ({stats['bounding_box'][0]}, {stats['bounding_box'][1]}) to ({stats['bounding_box'][2]}, {stats['bounding_box'][3]})",
        "",
    ]
    return "\n".join(summary_lines)
|
|
|
|
|
|
|
|
def generate_grid_points(
    image_size: Tuple[int, int],
    points_per_side: int = 32
) -> np.ndarray:
    """
    Build an evenly spaced grid of prompt points covering the whole image.

    Points run from pixel 0 to the last pixel on each axis, both ends
    included. Rows are ordered with x varying fastest: all points of the
    first row (y = 0) come first, then the next row, and so on.

    Args:
        image_size: Tuple of (height, width)
        points_per_side: Number of grid points along each axis

    Returns:
        float32 array of shape (points_per_side**2, 2); each row is [x, y].
    """
    n_rows, n_cols = image_size

    xs = np.linspace(0, n_cols - 1, points_per_side)
    ys = np.linspace(0, n_rows - 1, points_per_side)

    # tile(xs) cycles x within each row; repeat(ys) holds y constant per row.
    all_x = np.tile(xs, ys.size)
    all_y = np.repeat(ys, xs.size)

    return np.column_stack((all_x, all_y)).astype(np.float32)
|
|
|
|
|
|
|
|
def calculate_dice_score(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Calculate Dice coefficient between two masks.

    Any non-zero element is foreground. Boolean, 0/1 and 0/255 masks
    therefore give the same score, because the denominator counts foreground
    pixels rather than summing pixel values.

    Args:
        mask1: First binary mask
        mask2: Second binary mask (must broadcast against mask1)

    Returns:
        Dice coefficient (0.0 to 1.0). Two empty masks count as a perfect
        match and return 1.0.
    """
    m1 = np.asarray(mask1, dtype=bool)
    m2 = np.asarray(mask2, dtype=bool)

    intersection = int(np.count_nonzero(m1 & m2))
    total = int(np.count_nonzero(m1)) + int(np.count_nonzero(m2))
    if total == 0:
        return 1.0
    return (2.0 * intersection) / total
|
|
|
|
|
|
|
|
def calculate_iou_score(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Compute the Jaccard index (intersection over union) of two masks.

    Overlap and coverage use ``np.logical_and`` / ``np.logical_or``, so any
    non-zero element counts as foreground.

    Args:
        mask1: First binary mask
        mask2: Second binary mask

    Returns:
        IoU score in [0.0, 1.0]; 1.0 when both masks are empty.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    covered = np.logical_or(mask1, mask2).sum()

    if covered != 0:
        return overlap / covered
    return 0.0 if overlap != 0 else 1.0
|
|
|
|
|
|