File size: 9,779 Bytes
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
Core segmentation functions for NeuroSAM 3 application.
Handles segmentation operations, ROI statistics, and mask processing.
"""

from typing import Optional, Tuple, Dict, Any, List
import os
import tempfile
import numpy as np
import pydicom
from PIL import Image
import matplotlib.pyplot as plt
from scipy import ndimage
from logger_config import logger
from config import OUTPUT_DPI
from utils import combine_masks


def compare_with_ground_truth(
    pred_mask: np.ndarray,
    gt_mask_path: str
) -> Tuple[Optional[str], float, float]:
    """
    Compare SAM 3 prediction with ground truth mask and return comparison metrics.

    The ground truth image is converted to 8-bit grayscale and binarized at
    > 127. A non-bool prediction is binarized at > 0.5, which suits 0/1, 0/255
    and [0, 1] probability masks (NOTE(review): confirm this matches what
    callers pass, e.g. if logits are ever passed in). If the prediction's shape
    differs from the ground truth, it is resized to the ground-truth size with
    nearest-neighbour sampling.

    Args:
        pred_mask: Predicted mask array
        gt_mask_path: Path to ground truth mask image

    Returns:
        Tuple of (comparison_image_path, dice_score, iou_score). The PNG is
        written to a temporary file that this function does not delete. Both
        scores are 0.0 when both masks are empty. On any error the exception
        is logged and (None, 0.0, 0.0) is returned.
    """
    fig = None
    try:
        # The context manager releases the file handle PIL keeps open for
        # lazily-loaded images.
        with Image.open(gt_mask_path) as gt_mask:
            gt_array = np.array(gt_mask.convert('L')) > 127  # Binarize
            gt_size = gt_mask.size  # (width, height), the order resize() expects

        # Work on a bool mask. The &, ~ and boolean-indexing operations below
        # misbehave on integer arrays (a uint8 result is used as an integer
        # index) and raise on float arrays.
        pred_mask = np.asarray(pred_mask)
        if pred_mask.dtype != bool:
            pred_mask = pred_mask > 0.5

        # Resize prediction mask to match ground truth if needed
        if pred_mask.shape != gt_array.shape:
            pred_pil = Image.fromarray(pred_mask.astype(np.uint8) * 255)
            pred_pil = pred_pil.resize(gt_size, Image.NEAREST)
            pred_mask = np.array(pred_pil) > 127

        # Calculate metrics (both empty -> 0.0, unlike calculate_dice_score)
        intersection = np.logical_and(pred_mask, gt_array).sum()
        union = np.logical_or(pred_mask, gt_array).sum()
        total_area = pred_mask.sum() + gt_array.sum()
        dice_score = float(2.0 * intersection / total_area) if total_area > 0 else 0.0
        iou_score = float(intersection / union) if union > 0 else 0.0

        # Create comparison visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(pred_mask, cmap='spring')
        axes[0].set_title('SAM 3 Prediction')
        axes[0].axis('off')

        axes[1].imshow(gt_array, cmap='cool')
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')

        # Overlay comparison
        comparison = np.zeros((*pred_mask.shape, 3))
        comparison[pred_mask & gt_array] = [0, 1, 0]  # Green: True Positive
        comparison[pred_mask & ~gt_array] = [1, 0, 0]  # Red: False Positive
        comparison[~pred_mask & gt_array] = [0, 0, 1]  # Blue: False Negative

        axes[2].imshow(comparison)
        axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
        axes[2].axis('off')

        fig.tight_layout()

        # delete=False: the caller receives the path and owns the file
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        fig.savefig(output_path, bbox_inches='tight', dpi=OUTPUT_DPI)

        return output_path, dice_score, iou_score
    except Exception as e:
        logger.error(f"Error comparing with ground truth: {e}", exc_info=True)
        return None, 0.0, 0.0
    finally:
        # Always release the figure, including when an exception was raised
        # after it was created.
        if fig is not None:
            plt.close(fig)


def _empty_roi_stats(error: str) -> Dict[str, Any]:
    """Return the zero-filled statistics dictionary used for every failure case."""
    return {
        "error": error,
        "area_pixels": 0,
        "area_percentage": 0,
        "mean_intensity": 0,
        "std_intensity": 0,
        "min_intensity": 0,
        "max_intensity": 0,
        "centroid": (0, 0),
        "bounding_box": (0, 0, 0, 0)
    }


def calculate_roi_statistics(
    image_file: str,
    mask: np.ndarray,
    modality: str
) -> Dict[str, Any]:
    """
    Calculate ROI statistics from the segmented region.

    DICOM files (``.dcm``) are read with pydicom, and the RescaleSlope /
    RescaleIntercept tags are applied when present. Other files are opened
    with PIL. Multi-band and palette images are collapsed to one 8-bit
    grayscale channel, and single-band images are used as-is. If the mask's
    shape differs from the image, it is resized with nearest-neighbour
    zooming.

    Args:
        image_file: Path to original image file
        mask: Mask array; any non-zero value counts as ROI. Extra singleton
            axes (e.g. a (1, H, W) mask) are squeezed away.
        modality: Imaging modality ("CT" or "MRI"); currently not used here

    Returns:
        Dictionary with area (pixels and % of image), mean/std/min/max
        intensity, centroid as (x, y) and bounding box as
        (x_min, y_min, x_max, y_max). On failure, or when the ROI is empty,
        all values are zero and an "error" key describes the problem.
    """
    if mask is None or not isinstance(mask, np.ndarray):
        return _empty_roi_stats("No valid mask available")

    try:
        # Load original image for intensity statistics
        file_path = str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            # Missing rescale tags default to the identity mapping
            slope = float(getattr(ds, 'RescaleSlope', 1))
            intercept = float(getattr(ds, 'RescaleIntercept', 0))
            img_array = img_array * slope + intercept
        else:
            # `with` closes the file handle PIL holds for lazily-loaded images
            with Image.open(file_path) as img:
                # Colour, alpha-carrying and palette images would otherwise give
                # a multi-channel array (alpha mixed into the stats) or palette
                # indices instead of intensities.
                if img.mode == 'P' or len(img.getbands()) > 1:
                    img = img.convert('L')
                img_array = np.array(img).astype(np.float32)

        # Drop singleton axes so a (1, H, W) mask can be matched to a 2-D image
        if mask.ndim > 2:
            mask = np.squeeze(mask)

        # Resize mask if needed (order=0 keeps it binary)
        if mask.shape != img_array.shape:
            zoom_factors = (
                img_array.shape[0] / mask.shape[0],
                img_array.shape[1] / mask.shape[1]
            )
            mask = ndimage.zoom(mask.astype(float), zoom_factors, order=0) > 0.5

        mask_bool = mask.astype(bool)
        roi_pixels = int(np.count_nonzero(mask_bool))
        if roi_pixels == 0:
            return _empty_roi_stats("No pixels in ROI")

        roi_intensities = img_array[mask_bool]
        # ROI is non-empty here, so the coordinate arrays are never empty
        y_coords, x_coords = np.where(mask_bool)

        return {
            "area_pixels": roi_pixels,
            "area_percentage": float(roi_pixels / mask_bool.size * 100),
            "mean_intensity": float(np.mean(roi_intensities)),
            "std_intensity": float(np.std(roi_intensities)),
            "min_intensity": float(np.min(roi_intensities)),
            "max_intensity": float(np.max(roi_intensities)),
            "centroid": (float(np.mean(x_coords)), float(np.mean(y_coords))),
            "bounding_box": (
                int(np.min(x_coords)),
                int(np.min(y_coords)),
                int(np.max(x_coords)),
                int(np.max(y_coords)),
            )
        }
    except Exception as e:
        logger.error(f"Error calculating ROI statistics: {e}", exc_info=True)
        return _empty_roi_stats(str(e))


def format_roi_statistics(stats: Dict[str, Any]) -> str:
    """
    Render a statistics dictionary as a Markdown bullet list.

    Args:
        stats: Dictionary produced by calculate_roi_statistics

    Returns:
        "❌ Error: <message>" when stats has an "error" key; otherwise a
        Markdown block, with a leading and a trailing newline, listing area,
        intensity summary, centroid and bounding box.
    """
    if "error" in stats:
        return "❌ Error: {}".format(stats["error"])

    # Joined with "\n"; the empty first/last entries produce the surrounding
    # newlines of the rendered block.
    lines = [
        "",
        "**ROI Statistics:**",
        "",
        f"- **Area**: {stats['area_pixels']} pixels ({stats['area_percentage']:.2f}% of image)",
        "- **Intensity**:",
        f"  - Mean: {stats['mean_intensity']:.2f}",
        f"  - Std: {stats['std_intensity']:.2f}",
        f"  - Min: {stats['min_intensity']:.2f}",
        f"  - Max: {stats['max_intensity']:.2f}",
        f"- **Centroid**: ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})",
        f"- **Bounding Box**: ({stats['bounding_box'][0]}, {stats['bounding_box'][1]}) "
        f"to ({stats['bounding_box'][2]}, {stats['bounding_box'][3]})",
        "",
    ]
    return "\n".join(lines)


def generate_grid_points(
    image_size: Tuple[int, int],
    points_per_side: int = 32
) -> np.ndarray:
    """
    Build an evenly spaced square grid of prompt points covering the image.

    Points run from the top-left pixel (0, 0) to the bottom-right pixel
    (width - 1, height - 1) inclusive, ordered row by row (x varies fastest).

    Args:
        image_size: Tuple of (height, width)
        points_per_side: Number of points along each axis

    Returns:
        float32 array of shape (points_per_side ** 2, 2) whose rows are [x, y]
    """
    rows, cols = image_size

    xs = np.linspace(0, cols - 1, points_per_side)
    ys = np.linspace(0, rows - 1, points_per_side)

    # Row-major order: repeat the full x sweep once per row, and hold each
    # y value constant across its row.
    grid_x = np.tile(xs, points_per_side)
    grid_y = np.repeat(ys, points_per_side)

    return np.column_stack((grid_x, grid_y)).astype(np.float32)


def calculate_dice_score(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Calculate Dice coefficient between two masks.

    Any non-zero element counts as foreground, so bool, 0/1 and 0/255 masks
    produce the same score.

    Args:
        mask1: First binary mask
        mask2: Second binary mask

    Returns:
        Dice coefficient (0.0 to 1.0); 1.0 when both masks are empty.
    """
    intersection = np.count_nonzero(np.logical_and(mask1, mask2))
    # Count foreground pixels rather than summing values. Summing a 0/255
    # mask would weight each pixel by 255 while the intersection counts it
    # once, skewing the score.
    total_area = np.count_nonzero(mask1) + np.count_nonzero(mask2)
    if total_area == 0:
        # Two empty masks agree perfectly.
        return 1.0
    return 2.0 * intersection / total_area


def calculate_iou_score(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Compute the Jaccard index (intersection over union) of two masks.

    Args:
        mask1: First binary mask
        mask2: Second binary mask

    Returns:
        IoU in [0.0, 1.0]; two empty masks are reported as a perfect 1.0.
    """
    both = np.logical_and(mask1, mask2)
    either = np.logical_or(mask1, mask2)
    n_both, n_either = both.sum(), either.sum()

    if n_either == 0:
        # Nothing in either mask: treat as full agreement.
        return 0.0 if n_both else 1.0

    return n_both / n_either