Add ground truth comparison: compare SAM 3 segmentation with ground truth masks (BraTS, Kaggle datasets)
Browse files
app.py
CHANGED
|
@@ -93,8 +93,73 @@ except:
|
|
| 93 |
except Exception as e:
|
| 94 |
print(f"⚠️ Could not create demo file: {e}")
|
| 95 |
|
| 96 |
-
def
|
| 97 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
if model is None or processor is None:
|
| 99 |
print("❌ Error: Model not loaded.")
|
| 100 |
return None
|
|
@@ -223,6 +288,7 @@ def process_medical_image(image_file, prompt_text, modality, window_type):
|
|
| 223 |
plt.figure(figsize=(10, 10))
|
| 224 |
plt.imshow(pil_image)
|
| 225 |
|
|
|
|
| 226 |
if 'masks' in results and results['masks'] is not None:
|
| 227 |
masks = results['masks'].cpu().numpy()
|
| 228 |
if len(masks) > 0:
|
|
@@ -243,6 +309,8 @@ def process_medical_image(image_file, prompt_text, modality, window_type):
|
|
| 243 |
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
|
| 244 |
plt.close()
|
| 245 |
|
|
|
|
|
|
|
| 246 |
return output_path
|
| 247 |
|
| 248 |
except pydicom.errors.InvalidDicomError as e:
|
|
@@ -279,6 +347,32 @@ def process_with_status(image_file, prompt_text, modality, window_type):
|
|
| 279 |
else:
|
| 280 |
return result, "✅ Segmentation complete!"
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
def process_sequence(image_files, prompt_text, modality, window_type):
|
| 283 |
"""Process multiple images from the same subject and return gallery of results."""
|
| 284 |
if model is None or processor is None:
|
|
@@ -445,6 +539,78 @@ with gr.Blocks() as demo:
|
|
| 445 |
interactive=False,
|
| 446 |
lines=5
|
| 447 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
# Single image processing
|
| 450 |
load_demo_btn.click(
|
|
@@ -465,6 +631,13 @@ with gr.Blocks() as demo:
|
|
| 465 |
inputs=[files_input, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
|
| 466 |
outputs=[gallery_output, status_batch_text]
|
| 467 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
if __name__ == "__main__":
|
| 470 |
demo.launch()
|
|
|
|
| 93 |
except Exception as e:
|
| 94 |
print(f"⚠️ Could not create demo file: {e}")
|
| 95 |
|
| 96 |
+
def compare_with_ground_truth(pred_mask, gt_mask_path):
|
| 97 |
+
"""Compare SAM 3 prediction with ground truth mask and return comparison metrics."""
|
| 98 |
+
try:
|
| 99 |
+
gt_mask = Image.open(gt_mask_path)
|
| 100 |
+
gt_array = np.array(gt_mask.convert('L')) > 127 # Binarize
|
| 101 |
+
|
| 102 |
+
# Resize prediction mask to match ground truth if needed
|
| 103 |
+
if pred_mask.shape != gt_array.shape:
|
| 104 |
+
from PIL import Image as PILImage
|
| 105 |
+
pred_pil = PILImage.fromarray((pred_mask * 255).astype(np.uint8))
|
| 106 |
+
pred_pil = pred_pil.resize(gt_mask.size, PILImage.NEAREST)
|
| 107 |
+
pred_mask = np.array(pred_pil) > 127
|
| 108 |
+
|
| 109 |
+
# Calculate metrics
|
| 110 |
+
intersection = np.logical_and(pred_mask, gt_array).sum()
|
| 111 |
+
union = np.logical_or(pred_mask, gt_array).sum()
|
| 112 |
+
dice_score = (2.0 * intersection) / (pred_mask.sum() + gt_array.sum()) if (pred_mask.sum() + gt_array.sum()) > 0 else 0.0
|
| 113 |
+
iou_score = intersection / union if union > 0 else 0.0
|
| 114 |
+
|
| 115 |
+
# Create comparison visualization
|
| 116 |
+
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 117 |
+
|
| 118 |
+
axes[0].imshow(pred_mask, cmap='spring')
|
| 119 |
+
axes[0].set_title('SAM 3 Prediction')
|
| 120 |
+
axes[0].axis('off')
|
| 121 |
+
|
| 122 |
+
axes[1].imshow(gt_array, cmap='cool')
|
| 123 |
+
axes[1].set_title('Ground Truth')
|
| 124 |
+
axes[1].axis('off')
|
| 125 |
+
|
| 126 |
+
# Overlay comparison
|
| 127 |
+
comparison = np.zeros((*pred_mask.shape, 3))
|
| 128 |
+
comparison[pred_mask & gt_array] = [0, 1, 0] # Green: True Positive
|
| 129 |
+
comparison[pred_mask & ~gt_array] = [1, 0, 0] # Red: False Positive
|
| 130 |
+
comparison[~pred_mask & gt_array] = [0, 0, 1] # Blue: False Negative
|
| 131 |
+
|
| 132 |
+
axes[2].imshow(comparison)
|
| 133 |
+
axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
|
| 134 |
+
axes[2].axis('off')
|
| 135 |
+
|
| 136 |
+
plt.tight_layout()
|
| 137 |
+
|
| 138 |
+
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 139 |
+
output_path = output_file.name
|
| 140 |
+
output_file.close()
|
| 141 |
+
|
| 142 |
+
plt.savefig(output_path, bbox_inches='tight', dpi=100)
|
| 143 |
+
plt.close()
|
| 144 |
+
|
| 145 |
+
return output_path, dice_score, iou_score
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"⚠️ Error comparing with ground truth: {e}")
|
| 148 |
+
return None, 0.0, 0.0
|
| 149 |
+
|
| 150 |
+
def process_medical_image(image_file, prompt_text, modality, window_type, return_mask=False):
|
| 151 |
+
"""Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
image_file: Path to image file
|
| 155 |
+
prompt_text: Text prompt for segmentation
|
| 156 |
+
modality: CT or MRI
|
| 157 |
+
window_type: Windowing strategy
|
| 158 |
+
return_mask: If True, also return the binary mask array
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
Path to output image, and optionally the mask array
|
| 162 |
+
"""
|
| 163 |
if model is None or processor is None:
|
| 164 |
print("❌ Error: Model not loaded.")
|
| 165 |
return None
|
|
|
|
| 288 |
plt.figure(figsize=(10, 10))
|
| 289 |
plt.imshow(pil_image)
|
| 290 |
|
| 291 |
+
final_mask = None
|
| 292 |
if 'masks' in results and results['masks'] is not None:
|
| 293 |
masks = results['masks'].cpu().numpy()
|
| 294 |
if len(masks) > 0:
|
|
|
|
| 309 |
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
|
| 310 |
plt.close()
|
| 311 |
|
| 312 |
+
if return_mask:
|
| 313 |
+
return output_path, final_mask
|
| 314 |
return output_path
|
| 315 |
|
| 316 |
except pydicom.errors.InvalidDicomError as e:
|
|
|
|
| 347 |
else:
|
| 348 |
return result, "✅ Segmentation complete!"
|
| 349 |
|
| 350 |
+
def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type):
|
| 351 |
+
"""Process image and compare with ground truth segmentation mask."""
|
| 352 |
+
if model is None or processor is None:
|
| 353 |
+
return None, None, 0.0, 0.0, "❌ Error: Model not loaded."
|
| 354 |
+
|
| 355 |
+
if image_file is None:
|
| 356 |
+
return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file."
|
| 357 |
+
|
| 358 |
+
if gt_mask_file is None:
|
| 359 |
+
return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file."
|
| 360 |
+
|
| 361 |
+
# Process image and get mask
|
| 362 |
+
result, pred_mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
|
| 363 |
+
|
| 364 |
+
if result is None or pred_mask is None:
|
| 365 |
+
return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details."
|
| 366 |
+
|
| 367 |
+
# Compare with ground truth
|
| 368 |
+
comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file)
|
| 369 |
+
|
| 370 |
+
if comparison_path:
|
| 371 |
+
status = f"✅ Segmentation complete!\nDice Score: {dice_score:.3f}\nIoU Score: {iou_score:.3f}"
|
| 372 |
+
return result, comparison_path, dice_score, iou_score, status
|
| 373 |
+
else:
|
| 374 |
+
return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed."
|
| 375 |
+
|
| 376 |
def process_sequence(image_files, prompt_text, modality, window_type):
|
| 377 |
"""Process multiple images from the same subject and return gallery of results."""
|
| 378 |
if model is None or processor is None:
|
|
|
|
| 539 |
interactive=False,
|
| 540 |
lines=5
|
| 541 |
)
|
| 542 |
+
|
| 543 |
+
with gr.Tab("Compare with Ground Truth"):
|
| 544 |
+
gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**")
|
| 545 |
+
with gr.Row():
|
| 546 |
+
with gr.Column():
|
| 547 |
+
file_input_gt = gr.File(
|
| 548 |
+
label="Upload Medical Image (DICOM .dcm, PNG, JPG)",
|
| 549 |
+
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
|
| 550 |
+
type="filepath"
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
gt_mask_input = gr.File(
|
| 554 |
+
label="Upload Ground Truth Mask (PNG, JPG)",
|
| 555 |
+
file_types=[".png", ".jpg", ".jpeg"],
|
| 556 |
+
type="filepath",
|
| 557 |
+
info="Upload the ground truth segmentation mask (binary or labeled image)"
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
text_input_gt = gr.Textbox(
|
| 561 |
+
label="Text Prompt",
|
| 562 |
+
value="brain",
|
| 563 |
+
placeholder="e.g. brain, tumor, skull",
|
| 564 |
+
info="Describe what anatomical structure or region you want to segment"
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
with gr.Row():
|
| 568 |
+
modality_dropdown_gt = gr.Dropdown(
|
| 569 |
+
["CT", "MRI"],
|
| 570 |
+
label="Modality",
|
| 571 |
+
value="MRI",
|
| 572 |
+
info="Select the imaging modality"
|
| 573 |
+
)
|
| 574 |
+
window_dropdown_gt = gr.Dropdown(
|
| 575 |
+
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
|
| 576 |
+
label="Windowing Strategy (CT only)",
|
| 577 |
+
value="Brain (Grey Matter)",
|
| 578 |
+
info="CT windowing preset (ignored for MRI)"
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg")
|
| 582 |
+
|
| 583 |
+
with gr.Column():
|
| 584 |
+
image_output_gt = gr.Image(
|
| 585 |
+
label="SAM 3 Segmentation",
|
| 586 |
+
type="filepath"
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
comparison_output = gr.Image(
|
| 590 |
+
label="Comparison: SAM 3 vs Ground Truth",
|
| 591 |
+
type="filepath"
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
gr.Markdown("### Metrics")
|
| 595 |
+
dice_score_text = gr.Textbox(
|
| 596 |
+
label="Dice Score",
|
| 597 |
+
value="--",
|
| 598 |
+
interactive=False
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
iou_score_text = gr.Textbox(
|
| 602 |
+
label="IoU Score",
|
| 603 |
+
value="--",
|
| 604 |
+
interactive=False
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
gr.Markdown("### Status")
|
| 608 |
+
status_gt_text = gr.Textbox(
|
| 609 |
+
label="Processing Status",
|
| 610 |
+
value="Ready. Upload image and ground truth mask to compare.",
|
| 611 |
+
interactive=False,
|
| 612 |
+
lines=3
|
| 613 |
+
)
|
| 614 |
|
| 615 |
# Single image processing
|
| 616 |
load_demo_btn.click(
|
|
|
|
| 631 |
inputs=[files_input, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
|
| 632 |
outputs=[gallery_output, status_batch_text]
|
| 633 |
)
|
| 634 |
+
|
| 635 |
+
# Ground truth comparison
|
| 636 |
+
submit_gt_btn.click(
|
| 637 |
+
fn=process_with_ground_truth,
|
| 638 |
+
inputs=[file_input_gt, gt_mask_input, text_input_gt, modality_dropdown_gt, window_dropdown_gt],
|
| 639 |
+
outputs=[image_output_gt, comparison_output, dice_score_text, iou_score_text, status_gt_text]
|
| 640 |
+
)
|
| 641 |
|
| 642 |
if __name__ == "__main__":
|
| 643 |
demo.launch()
|