mmrech committed on
Commit
db851de
·
1 Parent(s): fbc936c

Add ground truth comparison: compare SAM 3 segmentation with ground truth masks (BraTS, Kaggle datasets)

Browse files
Files changed (1) hide show
  1. app.py +175 -2
app.py CHANGED
@@ -93,8 +93,73 @@ except:
93
  except Exception as e:
94
  print(f"⚠️ Could not create demo file: {e}")
95
 
96
- def process_medical_image(image_file, prompt_text, modality, window_type):
97
- """Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  if model is None or processor is None:
99
  print("❌ Error: Model not loaded.")
100
  return None
@@ -223,6 +288,7 @@ def process_medical_image(image_file, prompt_text, modality, window_type):
223
  plt.figure(figsize=(10, 10))
224
  plt.imshow(pil_image)
225
 
 
226
  if 'masks' in results and results['masks'] is not None:
227
  masks = results['masks'].cpu().numpy()
228
  if len(masks) > 0:
@@ -243,6 +309,8 @@ def process_medical_image(image_file, prompt_text, modality, window_type):
243
  plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
244
  plt.close()
245
 
 
 
246
  return output_path
247
 
248
  except pydicom.errors.InvalidDicomError as e:
@@ -279,6 +347,32 @@ def process_with_status(image_file, prompt_text, modality, window_type):
279
  else:
280
  return result, "✅ Segmentation complete!"
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  def process_sequence(image_files, prompt_text, modality, window_type):
283
  """Process multiple images from the same subject and return gallery of results."""
284
  if model is None or processor is None:
@@ -445,6 +539,78 @@ with gr.Blocks() as demo:
445
  interactive=False,
446
  lines=5
447
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
  # Single image processing
450
  load_demo_btn.click(
@@ -465,6 +631,13 @@ with gr.Blocks() as demo:
465
  inputs=[files_input, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
466
  outputs=[gallery_output, status_batch_text]
467
  )
 
 
 
 
 
 
 
468
 
469
  if __name__ == "__main__":
470
  demo.launch()
 
93
  except Exception as e:
94
  print(f"⚠️ Could not create demo file: {e}")
95
 
96
def compare_with_ground_truth(pred_mask, gt_mask_path):
    """Compare a predicted segmentation mask with a ground truth mask image.

    Computes the Dice and IoU scores between the two binary masks and renders
    a three-panel figure (prediction, ground truth, colour-coded overlay) to a
    temporary PNG file.

    Args:
        pred_mask: Predicted mask as an array-like. Boolean arrays are used as
            is; any other dtype is binarized at 0.5. Arrays with more than two
            dimensions have their singleton axes squeezed away.
        gt_mask_path: Path to the ground truth mask image (opened with PIL).

    Returns:
        A tuple ``(output_path, dice_score, iou_score)`` where the scores are
        plain Python floats. On any failure the error is printed and
        ``(None, 0.0, 0.0)`` is returned.
    """
    try:
        # Use a context manager so the file handle is released right away;
        # only the decoded greyscale array is needed afterwards.
        with Image.open(gt_mask_path) as gt_image:
            gt_gray = np.array(gt_image.convert('L'))

        gt_array = gt_gray > 127  # Binarize
        if not gt_array.any() and gt_gray.any():
            # Labeled masks store small class ids (e.g. 1, 2, 4) that never
            # exceed 127 and would otherwise yield an empty ground truth.
            # Only fall back when the 127 threshold found nothing, so regular
            # 0/255 masks are binarized exactly as before.
            gt_array = gt_gray > 0

        pred_mask = np.asarray(pred_mask)
        if pred_mask.ndim > 2:
            # Drop singleton axes such as a leading channel dimension.
            pred_mask = np.squeeze(pred_mask)
        if pred_mask.ndim != 2:
            raise ValueError(
                f"Expected a 2-D prediction mask, got shape {pred_mask.shape}"
            )
        if pred_mask.dtype != np.bool_:
            # The `&` / `~` operations below are only correct on booleans.
            pred_mask = pred_mask > 0.5

        # Resize prediction mask to match ground truth if needed
        if pred_mask.shape != gt_array.shape:
            pred_pil = Image.fromarray(pred_mask.astype(np.uint8) * 255)
            # PIL sizes are (width, height), NumPy shapes are (rows, cols).
            target_size = (gt_array.shape[1], gt_array.shape[0])
            pred_pil = pred_pil.resize(target_size, Image.NEAREST)
            pred_mask = np.array(pred_pil) > 127

        # Calculate metrics
        intersection = np.logical_and(pred_mask, gt_array).sum()
        union = np.logical_or(pred_mask, gt_array).sum()
        total = pred_mask.sum() + gt_array.sum()
        dice_score = float(2.0 * intersection / total) if total > 0 else 0.0
        iou_score = float(intersection / union) if union > 0 else 0.0

        # Create comparison visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        try:
            axes[0].imshow(pred_mask, cmap='spring')
            axes[0].set_title('SAM 3 Prediction')
            axes[0].axis('off')

            axes[1].imshow(gt_array, cmap='cool')
            axes[1].set_title('Ground Truth')
            axes[1].axis('off')

            # Overlay comparison
            comparison = np.zeros((*pred_mask.shape, 3))
            comparison[pred_mask & gt_array] = [0, 1, 0]   # Green: True Positive
            comparison[pred_mask & ~gt_array] = [1, 0, 0]  # Red: False Positive
            comparison[~pred_mask & gt_array] = [0, 0, 1]  # Blue: False Negative

            axes[2].imshow(comparison)
            axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
            axes[2].axis('off')

            fig.tight_layout()

            # Reserve a path that outlives this function; the caller hands it
            # to the UI, so the file must not be deleted on close.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as output_file:
                output_path = output_file.name

            fig.savefig(output_path, bbox_inches='tight', dpi=100)
        finally:
            # Always release the figure, even when rendering or saving fails.
            plt.close(fig)

        return output_path, dice_score, iou_score
    except Exception as e:
        # Best-effort boundary: callers expect the sentinel tuple, not a raise.
        print(f"⚠️ Error comparing with ground truth: {e}")
        return None, 0.0, 0.0
149
+
150
+ def process_medical_image(image_file, prompt_text, modality, window_type, return_mask=False):
151
+ """Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3.
152
+
153
+ Args:
154
+ image_file: Path to image file
155
+ prompt_text: Text prompt for segmentation
156
+ modality: CT or MRI
157
+ window_type: Windowing strategy
158
+ return_mask: If True, also return the binary mask array
159
+
160
+ Returns:
161
+ Path to output image, and optionally the mask array
162
+ """
163
  if model is None or processor is None:
164
  print("❌ Error: Model not loaded.")
165
  return None
 
288
  plt.figure(figsize=(10, 10))
289
  plt.imshow(pil_image)
290
 
291
+ final_mask = None
292
  if 'masks' in results and results['masks'] is not None:
293
  masks = results['masks'].cpu().numpy()
294
  if len(masks) > 0:
 
309
  plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
310
  plt.close()
311
 
312
+ if return_mask:
313
+ return output_path, final_mask
314
  return output_path
315
 
316
  except pydicom.errors.InvalidDicomError as e:
 
347
  else:
348
  return result, "✅ Segmentation complete!"
349
 
350
def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type):
    """Segment an image and compare the result with a ground truth mask.

    Args:
        image_file: Path to the medical image to segment.
        gt_mask_file: Path to the ground truth mask image.
        prompt_text: Text prompt passed to the segmentation step.
        modality: Imaging modality passed to the segmentation step.
        window_type: Windowing strategy passed to the segmentation step.

    Returns:
        A 5-tuple ``(segmentation_path, comparison_path, dice_score,
        iou_score, status_message)``. Paths are ``None`` and scores are
        ``0.0`` for whichever stage did not complete.
    """
    if model is None or processor is None:
        return None, None, 0.0, 0.0, "❌ Error: Model not loaded."

    if image_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file."

    if gt_mask_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file."

    # Process image and get mask
    outcome = process_medical_image(
        image_file, prompt_text, modality, window_type, return_mask=True
    )

    # process_medical_image can return a bare None even when return_mask=True
    # (its early-exit path does), so unpacking the result directly would
    # raise TypeError instead of reporting a failure to the user.
    if isinstance(outcome, tuple):
        result, pred_mask = outcome
    else:
        result, pred_mask = outcome, None

    if result is None:
        return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details."

    if pred_mask is None:
        # Segmentation ran but produced no mask: still show its output rather
        # than discarding it and reporting a processing failure.
        return (
            result,
            None,
            0.0,
            0.0,
            "⚠️ Segmentation complete, but no mask was produced to compare.",
        )

    # Compare with ground truth
    comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file)

    if comparison_path:
        status = f"✅ Segmentation complete!\nDice Score: {dice_score:.3f}\nIoU Score: {iou_score:.3f}"
        return result, comparison_path, dice_score, iou_score, status

    return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed."
375
+
376
  def process_sequence(image_files, prompt_text, modality, window_type):
377
  """Process multiple images from the same subject and return gallery of results."""
378
  if model is None or processor is None:
 
539
  interactive=False,
540
  lines=5
541
  )
542
+
543
with gr.Tab("Compare with Ground Truth"):
    # Tab layout: inputs (image, ground truth mask, prompt, modality/window)
    # on the left; segmentation, comparison figure, metrics and status on
    # the right.
    gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**")
    with gr.Row():
        with gr.Column():
            file_input_gt = gr.File(
                label="Upload Medical Image (DICOM .dcm, PNG, JPG)",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            # NOTE(review): gr.File appears to have no `info` parameter in
            # current Gradio releases, so passing one would raise TypeError
            # while the UI is built — confirm against the pinned Gradio
            # version. The hint is shown as Markdown below the upload widget
            # instead.
            gt_mask_input = gr.File(
                label="Upload Ground Truth Mask (PNG, JPG)",
                file_types=[".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            gr.Markdown("Upload the ground truth segmentation mask (binary or labeled image)")

            text_input_gt = gr.Textbox(
                label="Text Prompt",
                value="brain",
                placeholder="e.g. brain, tumor, skull",
                info="Describe what anatomical structure or region you want to segment"
            )

            with gr.Row():
                modality_dropdown_gt = gr.Dropdown(
                    ["CT", "MRI"],
                    label="Modality",
                    value="MRI",
                    info="Select the imaging modality"
                )
                window_dropdown_gt = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing Strategy (CT only)",
                    value="Brain (Grey Matter)",
                    info="CT windowing preset (ignored for MRI)"
                )

            submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg")

        with gr.Column():
            image_output_gt = gr.Image(
                label="SAM 3 Segmentation",
                type="filepath"
            )

            comparison_output = gr.Image(
                label="Comparison: SAM 3 vs Ground Truth",
                type="filepath"
            )

            gr.Markdown("### Metrics")
            dice_score_text = gr.Textbox(
                label="Dice Score",
                value="--",
                interactive=False
            )

            iou_score_text = gr.Textbox(
                label="IoU Score",
                value="--",
                interactive=False
            )

            gr.Markdown("### Status")
            status_gt_text = gr.Textbox(
                label="Processing Status",
                value="Ready. Upload image and ground truth mask to compare.",
                interactive=False,
                lines=3
            )
614
 
615
  # Single image processing
616
  load_demo_btn.click(
 
631
  inputs=[files_input, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
632
  outputs=[gallery_output, status_batch_text]
633
  )
634
+
635
+ # Ground truth comparison
636
+ submit_gt_btn.click(
637
+ fn=process_with_ground_truth,
638
+ inputs=[file_input_gt, gt_mask_input, text_input_gt, modality_dropdown_gt, window_dropdown_gt],
639
+ outputs=[image_output_gt, comparison_output, dice_score_text, iou_score_text, status_gt_text]
640
+ )
641
 
642
  if __name__ == "__main__":
643
  demo.launch()