NeuroSAM3 / app.py.backup
mmrech's picture
Refactor codebase: Add modular structure, logging, validation, and comprehensive improvements
69066c5
"""
NeuroSAM 3: Medical Image Segmentation App
A Gradio app for segmenting medical images (CT/MRI) using SAM 3
"""
import os
import tempfile
import zipfile
import io
import json
import time
from datetime import datetime
import gradio as gr
import spaces
import torch
import pydicom
import numpy as np
from PIL import Image, ImageEnhance, ImageDraw
try:
from transformers import Sam3Processor, Sam3Model
SAM3_AVAILABLE = True
except ImportError:
print("⚠️ Warning: Sam3Processor/Sam3Model not found in transformers.")
print("⚠️ SAM3 requires transformers from GitHub main branch.")
print("⚠️ Install with: pip install git+https://github.com/huggingface/transformers.git")
SAM3_AVAILABLE = False
# Create dummy classes to prevent import errors
Sam3Processor = None
Sam3Model = None
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy import ndimage
from huggingface_hub import login
# Try to import nibabel for NIFTI support (optional)
try:
import nibabel as nib
NIBABEL_AVAILABLE = True
except ImportError:
NIBABEL_AVAILABLE = False
print("⚠️ nibabel not available - NIFTI export disabled")
# Hugging Face Token (must be set as HF_TOKEN environment variable in Space settings)
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
print("⚠️ WARNING: HF_TOKEN environment variable not set!")
print("⚠️ Some features may not work. Please set HF_TOKEN in Space settings.")
hf_token = None # Allow app to start, but model loading will fail gracefully
else:
# Login to Hugging Face Hub (only if token is provided)
try:
login(token=hf_token, add_to_git_credential=False)
except Exception as e:
print(f"⚠️ Could not login to HF Hub (non-critical): {e}")
# Load SAM 3 Model
print("🧠 Loading SAM 3 Model...")
# IMPORTANT: For HF Spaces with Stateless GPU, load model on CPU in main process
# Model will be moved to GPU inside @spaces.GPU decorated functions
model = None
processor = None
if not SAM3_AVAILABLE:
print("❌ SAM 3 classes not available in transformers library.")
print("❌ Install with: pip install git+https://github.com/huggingface/transformers.git")
print("⚠️ App will start but segmentation features will be disabled.")
else:
# SAM 3 model identifier - matching official implementation
SAM_MODEL_ID = "facebook/sam3"
if hf_token is None:
print("⚠️ Cannot load model: HF_TOKEN not set")
model = None
processor = None
else:
try:
# Load model on CPU to avoid CUDA initialization in main process (for HF Spaces Stateless GPU)
# Model will be moved to GPU inside @spaces.GPU decorated functions
model = Sam3Model.from_pretrained(
SAM_MODEL_ID,
torch_dtype=torch.float32, # Load as float32 on CPU
token=hf_token
)
processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=hf_token)
model.eval()
print(f"✅ SAM 3 Model Loaded Successfully on CPU! ({SAM_MODEL_ID})")
print("💡 Model will be moved to GPU when inference is called")
except Exception as e:
print(f"⚠️ Failed to load SAM 3 model: {e}")
print("Ensure you have:")
print(" 1. transformers from GitHub main branch for SAM 3 support")
print(" Install with: pip install git+https://github.com/huggingface/transformers.git")
print(" 2. Valid Hugging Face token with access to SAM 3")
print(" 3. Sufficient memory for the model")
print("⚠️ App will start but segmentation features will be disabled until model loads.")
# Don't raise - allow app to start and show error in UI
model = None
processor = None
@spaces.GPU(duration=60)
def run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0):
"""
Run SAM 3 inference - optimized for medical imaging.
Args:
pil_image: PIL Image to segment
prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull")
threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images).
Lower values (0.0-0.3) are more permissive and better for subtle features.
Higher values (0.5-1.0) require high confidence, may miss detections.
mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images).
Lower values preserve more detail. Higher values create sharper masks.
Medical images often benefit from 0.0 to capture subtle boundaries.
Returns:
results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed
Note:
Default thresholds (0.1, 0.0) are optimized for medical imaging where features
may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5)
may be more appropriate.
"""
if model is None or processor is None:
print("❌ Model not loaded - please check HF_TOKEN and model availability")
raise ValueError("SAM 3 model not loaded. Please check that HF_TOKEN is set correctly and the model is accessible.")
def to_serializable(obj):
"""
Convert all tensors to numpy arrays or Python primitives for safe serialization.
This ensures NO PyTorch tensors (CPU or CUDA) are in the return value.
"""
if isinstance(obj, torch.Tensor):
# Convert to numpy array (works for both CPU and CUDA tensors)
result = obj.cpu().numpy()
print(f"🔄 Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}")
return result
elif isinstance(obj, dict):
return {k: to_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [to_serializable(item) for item in obj]
elif isinstance(obj, tuple):
return tuple(to_serializable(item) for item in obj)
elif isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif hasattr(obj, 'item'): # numpy scalar
return obj.item()
else:
# For unknown types, try to convert to string representation
print(f"⚠️ Unknown type encountered: {type(obj)}, converting to string")
return str(obj)
try:
# Determine device and move model to GPU if available (CUDA initialization happens here, inside @spaces.GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🔧 Using device: {device}")
# Move model to device and set appropriate dtype
# Note: For nn.Module, .to() modifies in-place and returns self
# IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued and processed
# one at a time, so there's NO concurrent access to the model. This makes in-place
# modification safe despite model being a global variable.
dtype = torch.float16 if device == "cuda" else torch.float32
model.to(device=device, dtype=dtype)
print(f"✅ Model moved to {device} with dtype {dtype}")
# Prepare inputs - matching official implementation
inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)
# Convert float32 inputs to model dtype (float16 for GPU) - matching official implementation
for key in inputs:
if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
inputs[key] = inputs[key].to(model.dtype)
with torch.no_grad():
outputs = model(**inputs)
print(f"🧠 Inference complete, processing results...")
# Post-process using processor method - matching official implementation
results = processor.post_process_instance_segmentation(
outputs,
threshold=threshold,
mask_threshold=mask_threshold,
target_sizes=inputs.get("original_sizes").tolist() if "original_sizes" in inputs else [pil_image.size[::-1]]
)[0] # Get first batch result
print(f"📊 Results type: {type(results)}")
if isinstance(results, dict):
print(f"📊 Results keys: {results.keys()}")
for key, value in results.items():
print(f" - {key}: type={type(value)}")
if isinstance(value, torch.Tensor):
print(f" tensor device={value.device}, shape={value.shape}, dtype={value.dtype}")
elif isinstance(value, list) and len(value) > 0:
print(f" list length={len(value)}, first item type={type(value[0])}")
if isinstance(value[0], torch.Tensor):
print(f" first tensor device={value[0].device}")
# CRITICAL: Convert ALL tensors to numpy arrays before returning
# This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary
# Numpy arrays are safely serializable without triggering CUDA init
print(f"🔄 Converting all tensors to numpy arrays...")
results = to_serializable(results)
print(f"✅ All tensors converted to serializable format")
# Move model back to CPU to free GPU memory (important for Spaces)
model.to("cpu")
print(f"✅ Model moved back to CPU")
return results
except Exception as e:
print(f"❌ Error during SAM 3 inference: {e}")
import traceback
traceback.print_exc()
# Make sure to move model back to CPU even on error
if model is not None:
try:
model.to("cpu")
except RuntimeError as cleanup_error:
print(f"⚠️ Could not move model back to CPU: {cleanup_error}")
return None
# Create Sample DICOM File for Demo
demo_dicom_path = "demo_brain_mri.dcm"
demo_file_available = False
try:
from pydicom.data import get_testdata_file
test_file = get_testdata_file("MR_small.dcm")
if test_file and os.path.exists(test_file):
import shutil
shutil.copy(test_file, demo_dicom_path)
demo_file_available = True
print(f"✅ Demo file ready: {demo_dicom_path}")
except:
try:
# Create synthetic DICOM file
from pydicom.dataset import FileDataset, FileMetaDataset
from pydicom.uid import generate_uid
from datetime import datetime
synthetic_image = np.random.randint(0, 255, (256, 256), dtype=np.uint16)
center_x, center_y = 128, 128
y, x = np.ogrid[:256, :256]
mask = (x - center_x)**2 + (y - center_y)**2 <= 100**2
synthetic_image[mask] = np.clip(synthetic_image[mask] + 50, 0, 255)
file_meta = FileMetaDataset()
file_meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.4'
file_meta.MediaStorageSOPInstanceUID = generate_uid()
file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'
ds = FileDataset(demo_dicom_path, {}, file_meta=file_meta, preamble=b"\x00" * 128)
ds.PatientName = "Demo^Patient"
ds.PatientID = "DEMO001"
ds.Modality = "MR"
ds.Rows = 256
ds.Columns = 256
ds.BitsAllocated = 16
ds.BitsStored = 16
ds.HighBit = 15
ds.SamplesPerPixel = 1
ds.PixelRepresentation = 0
ds.PhotometricInterpretation = "MONOCHROME2"
ds.PixelSpacing = [1.0, 1.0]
ds.RescaleIntercept = "0"
ds.RescaleSlope = "1"
ds.PixelData = synthetic_image.tobytes()
ds.save_as(demo_dicom_path, write_like_original=False)
demo_file_available = True
print(f"✅ Synthetic demo file created: {demo_dicom_path}")
except Exception as e:
print(f"⚠️ Could not create demo file: {e}")
def compare_with_ground_truth(pred_mask, gt_mask_path):
"""Compare SAM 3 prediction with ground truth mask and return comparison metrics."""
try:
gt_mask = Image.open(gt_mask_path)
gt_array = np.array(gt_mask.convert('L')) > 127 # Binarize
# Resize prediction mask to match ground truth if needed
if pred_mask.shape != gt_array.shape:
from PIL import Image as PILImage
pred_pil = PILImage.fromarray((pred_mask * 255).astype(np.uint8))
pred_pil = pred_pil.resize(gt_mask.size, PILImage.NEAREST)
pred_mask = np.array(pred_pil) > 127
# Calculate metrics
intersection = np.logical_and(pred_mask, gt_array).sum()
union = np.logical_or(pred_mask, gt_array).sum()
dice_score = (2.0 * intersection) / (pred_mask.sum() + gt_array.sum()) if (pred_mask.sum() + gt_array.sum()) > 0 else 0.0
iou_score = intersection / union if union > 0 else 0.0
# Create comparison visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(pred_mask, cmap='spring')
axes[0].set_title('SAM 3 Prediction')
axes[0].axis('off')
axes[1].imshow(gt_array, cmap='cool')
axes[1].set_title('Ground Truth')
axes[1].axis('off')
# Overlay comparison
comparison = np.zeros((*pred_mask.shape, 3))
comparison[pred_mask & gt_array] = [0, 1, 0] # Green: True Positive
comparison[pred_mask & ~gt_array] = [1, 0, 0] # Red: False Positive
comparison[~pred_mask & gt_array] = [0, 0, 1] # Blue: False Negative
axes[2].imshow(comparison)
axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
axes[2].axis('off')
plt.tight_layout()
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', dpi=100)
plt.close()
return output_path, dice_score, iou_score
except Exception as e:
print(f"⚠️ Error comparing with ground truth: {e}")
return None, 0.0, 0.0
def process_medical_image(image_file, prompt_text, modality, window_type, return_mask=False):
"""Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3.
Args:
image_file: Path to image file
prompt_text: Text prompt for segmentation
modality: CT or MRI
window_type: Windowing strategy
return_mask: If True, also return the binary mask array
Returns:
Path to output image, and optionally the mask array
"""
if model is None or processor is None:
print("❌ Error: Model not loaded.")
return None
if image_file is None:
return None
if not prompt_text or not prompt_text.strip():
prompt_text = "brain"
try:
file_path = image_file if isinstance(image_file, str) else str(image_file)
if not os.path.exists(file_path):
print(f"❌ Error: File not found at {file_path}")
return None
# Detect file type
file_ext = os.path.splitext(file_path)[1].lower()
is_dicom = file_ext == '.dcm'
if is_dicom:
# Process DICOM file
ds = pydicom.dcmread(file_path)
if not hasattr(ds, 'pixel_array'):
print("❌ Error: DICOM file does not contain pixel data.")
return None
raw = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_hu = raw * slope + intercept
# Apply Windowing
if modality == "CT":
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else: # MRI
img_min = np.percentile(img_hu, 1)
img_max = np.percentile(img_hu, 99)
img_range = img_max - img_min
if img_range <= 0:
img_min = np.min(img_hu)
img_max = np.max(img_hu)
img_range = img_max - img_min
if img_range <= 0:
return None
img_windowed = (img_hu - img_min) / img_range
img_windowed = np.clip(img_windowed, 0, 1)
img_uint8 = (img_windowed * 255).astype(np.uint8)
if len(img_uint8.shape) == 2:
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.fromarray(img_uint8)
else:
# Process standard image file (PNG, JPG, etc.)
pil_image = Image.open(file_path)
# Convert to RGB if needed
if pil_image.mode != 'RGB':
pil_image = pil_image.convert('RGB')
# Convert to numpy for normalization
img_array = np.array(pil_image)
# Handle grayscale images
if len(img_array.shape) == 2:
img_array = np.stack([img_array] * 3, axis=-1)
# Normalize image (percentile-based for MRI-like processing)
img_float = img_array.astype(np.float32)
if modality == "CT":
# For CT-like processing, use windowing
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else: # MRI - use percentile normalization
img_min = np.percentile(img_float, 1)
img_max = np.percentile(img_float, 99)
img_range = img_max - img_min
if img_range <= 0:
img_min = np.min(img_float)
img_max = np.max(img_float)
img_range = img_max - img_min
if img_range <= 0:
return None
img_normalized = (img_float - img_min) / img_range
img_normalized = np.clip(img_normalized, 0, 1)
img_uint8 = (img_normalized * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8.astype(np.uint8))
# Run SAM 3 Inference - using helper function matching official implementation
# Lower thresholds for medical images to ensure detections are not filtered out
results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
if results is None:
return None
# Draw Masks on Image - matching official implementation format
plt.figure(figsize=(10, 10))
plt.imshow(pil_image)
final_mask = None
if 'masks' in results and results['masks'] is not None:
masks = results['masks'] # List of mask tensors from post_process_instance_segmentation
scores = results.get('scores', [])
if len(masks) > 0:
# Combine all masks into one (or use first mask)
# Convert tensors to numpy and combine
mask_arrays = []
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
mask_arrays.append(mask_np)
# Combine all masks
if len(mask_arrays) > 0:
final_mask = np.any(mask_arrays, axis=0)
plt.imshow(final_mask, alpha=0.5, cmap='spring')
else:
print("⚠️ Warning: No valid masks found.")
else:
print("⚠️ Warning: No masks in results.")
else:
print("⚠️ Warning: No masks in results.")
plt.axis('off')
plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
if return_mask:
return output_path, final_mask
return output_path
except pydicom.errors.InvalidDicomError as e:
print(f"❌ Error: Invalid DICOM file format. {e}")
return None
except Exception as e:
print(f"❌ Error processing image: {e}")
import traceback
traceback.print_exc()
return None
def process_medical_image_enhanced(image_file, prompt_text, modality, window_type,
brightness=1.0, contrast=1.0, colormap='spring',
transparency=0.5, return_mask=False):
"""Enhanced version of process_medical_image with image adjustments and visualization options.
Args:
image_file: Path to image file
prompt_text: Text prompt for segmentation
modality: CT or MRI
window_type: Windowing strategy
brightness: Brightness multiplier (0.5-2.0)
contrast: Contrast multiplier (0.5-2.0)
colormap: Matplotlib colormap name
transparency: Mask overlay transparency (0.0-1.0)
return_mask: If True, also return the binary mask array
Returns:
Path to output image, and optionally the mask array
"""
if model is None or processor is None:
print("❌ Error: Model not loaded.")
return None
if image_file is None:
return None
if not prompt_text or not prompt_text.strip():
prompt_text = "brain"
try:
file_path = image_file if isinstance(image_file, str) else str(image_file)
if not os.path.exists(file_path):
print(f"❌ Error: File not found at {file_path}")
return None
# Detect file type
file_ext = os.path.splitext(file_path)[1].lower()
is_dicom = file_ext == '.dcm'
if is_dicom:
# Process DICOM file
ds = pydicom.dcmread(file_path)
if not hasattr(ds, 'pixel_array'):
print("❌ Error: DICOM file does not contain pixel data.")
return None
raw = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_hu = raw * slope + intercept
# Apply Windowing
if modality == "CT":
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else: # MRI
img_min = np.percentile(img_hu, 1)
img_max = np.percentile(img_hu, 99)
img_range = img_max - img_min
if img_range <= 0:
img_min = np.min(img_hu)
img_max = np.max(img_hu)
img_range = img_max - img_min
if img_range <= 0:
return None
img_windowed = (img_hu - img_min) / img_range
img_windowed = np.clip(img_windowed, 0, 1)
img_uint8 = (img_windowed * 255).astype(np.uint8)
if len(img_uint8.shape) == 2:
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.fromarray(img_uint8)
else:
# Process standard image file (PNG, JPG, etc.)
pil_image = Image.open(file_path)
# Convert to RGB if needed
if pil_image.mode != 'RGB':
pil_image = pil_image.convert('RGB')
# Convert to numpy for normalization
img_array = np.array(pil_image)
# Handle grayscale images
if len(img_array.shape) == 2:
img_array = np.stack([img_array] * 3, axis=-1)
# Normalize image (percentile-based for MRI-like processing)
img_float = img_array.astype(np.float32)
if modality == "CT":
# For CT-like processing, use windowing
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else: # MRI - use percentile normalization
img_min = np.percentile(img_float, 1)
img_max = np.percentile(img_float, 99)
img_range = img_max - img_min
if img_range <= 0:
img_min = np.min(img_float)
img_max = np.max(img_float)
img_range = img_max - img_min
if img_range <= 0:
return None
img_normalized = (img_float - img_min) / img_range
img_normalized = np.clip(img_normalized, 0, 1)
img_uint8 = (img_normalized * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8.astype(np.uint8))
# Apply brightness and contrast adjustments
enhancer = ImageEnhance.Brightness(pil_image)
pil_image = enhancer.enhance(brightness)
enhancer = ImageEnhance.Contrast(pil_image)
pil_image = enhancer.enhance(contrast)
# Run SAM 3 Inference - using helper function matching official implementation
# Lower thresholds for medical images to ensure detections are not filtered out
results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
if results is None:
return None
# Draw Masks on Image with enhanced visualization - matching official implementation format
plt.figure(figsize=(10, 10))
plt.imshow(pil_image)
final_mask = None
if 'masks' in results and results['masks'] is not None:
masks = results['masks'] # List of mask tensors from post_process_instance_segmentation
scores = results.get('scores', [])
if len(masks) > 0:
# Combine all masks into one
mask_arrays = []
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
mask_arrays.append(mask_np)
# Combine all masks
if len(mask_arrays) > 0:
final_mask = np.any(mask_arrays, axis=0)
plt.imshow(final_mask, alpha=transparency, cmap=colormap)
else:
print("⚠️ Warning: No valid masks found.")
else:
print("⚠️ Warning: No masks in results.")
else:
print("⚠️ Warning: No masks in results.")
plt.axis('off')
plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
if return_mask:
return output_path, final_mask
return output_path
except pydicom.errors.InvalidDicomError as e:
print(f"❌ Error: Invalid DICOM file format. {e}")
return None
except Exception as e:
print(f"❌ Error processing image: {e}")
import traceback
traceback.print_exc()
return None
def process_with_progress(image_file, prompt_text, modality, window_type,
brightness=1.0, contrast=1.0, colormap='spring',
transparency=0.5, progress=gr.Progress()):
"""Process with progress indicator."""
if model is None or processor is None:
return None, "❌ Error: Model not loaded.", ""
if image_file is None:
return None, "⚠️ Please upload a medical image file.", ""
progress(0, desc="Starting...")
progress(0.1, desc="Reading image...")
result = process_medical_image_enhanced(
image_file, prompt_text, modality, window_type,
brightness, contrast, colormap, transparency
)
progress(0.8, desc="Processing segmentation...")
if result is None:
progress(1.0, desc="Failed!")
return None, "❌ Processing failed. Check console for error details.", ""
progress(1.0, desc="Complete!")
# Calculate metrics
metrics = f"Prompt: {prompt_text}\nModality: {modality}\nProcessed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
return result, "✅ Segmentation complete!", metrics
def create_downloadable_result(output_path, prompt_text):
"""Create a downloadable file with metadata."""
return output_path # Gradio File component handles downloads automatically
def process_batch_enhanced(image_files, prompt_text, modality, window_type,
brightness=1.0, contrast=1.0, colormap='spring',
transparency=0.5, progress=gr.Progress()):
"""Process multiple images with enhanced features and create ZIP download."""
if model is None or processor is None:
return [], None, "❌ Error: Model not loaded."
if not image_files:
return [], None, "⚠️ Please upload medical image files."
# Handle single file or list of files
if isinstance(image_files, str):
image_files = [image_files]
results = []
total = len(image_files)
for idx, image_file in enumerate(image_files):
if image_file is None:
continue
progress((idx + 1) / total, desc=f"Processing image {idx + 1}/{total}...")
result = process_medical_image_enhanced(
image_file, prompt_text, modality, window_type,
brightness, contrast, colormap, transparency
)
if result:
results.append(result)
if not results:
return [], None, "❌ No images were processed successfully."
# Create ZIP file with all results
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
for idx, result_path in enumerate(results):
if os.path.exists(result_path):
filename = f"segmentation_{idx + 1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
zip_file.write(result_path, filename)
zip_buffer.seek(0)
zip_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
zip_path.write(zip_buffer.read())
zip_path.close()
status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download."
return results, zip_path.name, status
# ============================================================================
# ENHANCED FEATURES - Auto-play, Point/Box Prompts, ROI Stats, NIFTI Export
# ============================================================================
# Global state for auto-play
auto_play_state = {"running": False, "current_idx": 0}
def calculate_roi_statistics(image_file, mask, modality):
"""Calculate ROI statistics from the segmented region.
Returns:
dict: Statistics including area, mean intensity, std, min, max, centroid
"""
if mask is None or not isinstance(mask, np.ndarray):
return {
"error": "No valid mask available",
"area_pixels": 0,
"area_percentage": 0,
"mean_intensity": 0,
"std_intensity": 0,
"min_intensity": 0,
"max_intensity": 0,
"centroid": (0, 0),
"bounding_box": (0, 0, 0, 0)
}
try:
# Load original image for intensity statistics
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
else:
img = Image.open(file_path)
if img.mode == 'RGB':
img = img.convert('L') # Convert to grayscale for intensity stats
img_array = np.array(img).astype(np.float32)
# Resize mask if needed
if mask.shape != img_array.shape:
from scipy.ndimage import zoom
zoom_factors = (img_array.shape[0] / mask.shape[0], img_array.shape[1] / mask.shape[1])
mask = zoom(mask.astype(float), zoom_factors, order=0) > 0.5
# Calculate statistics
mask_bool = mask.astype(bool)
total_pixels = mask.size
roi_pixels = np.sum(mask_bool)
if roi_pixels == 0:
return {
"error": "No pixels in ROI",
"area_pixels": 0,
"area_percentage": 0,
"mean_intensity": 0,
"std_intensity": 0,
"min_intensity": 0,
"max_intensity": 0,
"centroid": (0, 0),
"bounding_box": (0, 0, 0, 0)
}
roi_intensities = img_array[mask_bool]
# Calculate centroid
labeled_mask, num_features = ndimage.label(mask_bool)
centroid = ndimage.center_of_mass(mask_bool)
# Calculate bounding box
rows = np.any(mask_bool, axis=1)
cols = np.any(mask_bool, axis=0)
rmin, rmax = np.where(rows)[0][[0, -1]]
cmin, cmax = np.where(cols)[0][[0, -1]]
stats = {
"area_pixels": int(roi_pixels),
"area_percentage": float(roi_pixels / total_pixels * 100),
"mean_intensity": float(np.mean(roi_intensities)),
"std_intensity": float(np.std(roi_intensities)),
"min_intensity": float(np.min(roi_intensities)),
"max_intensity": float(np.max(roi_intensities)),
"centroid": (float(centroid[1]), float(centroid[0])), # (x, y)
"bounding_box": (int(cmin), int(rmin), int(cmax), int(rmax)), # (x1, y1, x2, y2)
"num_components": num_features
}
# Add HU statistics for CT
if modality == "CT":
stats["mean_hu"] = stats["mean_intensity"]
stats["std_hu"] = stats["std_intensity"]
return stats
except Exception as e:
print(f"Error calculating ROI statistics: {e}")
return {"error": str(e)}
def format_roi_statistics(stats):
"""Format ROI statistics as a readable string."""
if "error" in stats and stats.get("area_pixels", 0) == 0:
return f"⚠️ {stats.get('error', 'No statistics available')}"
text = "📊 **ROI Statistics**\n\n"
text += f"**Area:** {stats['area_pixels']:,} pixels ({stats['area_percentage']:.2f}%)\n"
text += f"**Intensity:** {stats['mean_intensity']:.2f} ± {stats['std_intensity']:.2f}\n"
text += f"**Range:** [{stats['min_intensity']:.2f}, {stats['max_intensity']:.2f}]\n"
text += f"**Centroid:** ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})\n"
text += f"**Bounding Box:** {stats['bounding_box']}\n"
text += f"**Components:** {stats.get('num_components', 1)}"
if "mean_hu" in stats:
text += f"\n\n**CT (Hounsfield Units):**\n"
text += f"Mean HU: {stats['mean_hu']:.1f} ± {stats['std_hu']:.1f}"
return text
def process_with_roi_stats(image_file, prompt_text, modality, window_type):
"""Process image and return both segmentation and ROI statistics."""
if model is None or processor is None:
return None, "❌ Error: Model not loaded.", ""
if image_file is None:
return None, "⚠️ Please upload a medical image file.", ""
result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
if result is None:
return None, "❌ Processing failed.", ""
# Calculate ROI statistics
stats = calculate_roi_statistics(image_file, mask, modality)
stats_text = format_roi_statistics(stats)
return result, "✅ Segmentation complete!", stats_text
def process_with_point_prompt(image_file, point_x, point_y, modality, window_type, colormap='spring', transparency=0.5):
"""Process image with a point prompt for segmentation.
Note: This simulates point-based prompting by using the point location
as a seed for region-based segmentation.
"""
if model is None or processor is None:
return None, "❌ Error: Model not loaded."
if image_file is None:
return None, "⚠️ Please upload a medical image file."
try:
# Load image
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
# Normalize
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.open(file_path).convert('RGB')
img_array = np.array(pil_image)
h, w = img_array.shape[:2]
# Clamp point coordinates
point_x = max(0, min(int(point_x), w - 1))
point_y = max(0, min(int(point_y), h - 1))
# Create a prompt based on the point location
prompt_text = f"segment region at point"
# Process with SAM 3 - using helper function
# Lower thresholds for medical images to ensure detections are not filtered out
results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
final_mask = None
if results and 'masks' in results and results['masks'] is not None:
masks = results['masks']
# Select mask containing the point
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
# Resize to image size
mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
if mask_resized[point_y, point_x]:
final_mask = mask_resized
break
# If no mask contains the point, use first mask
if final_mask is None and len(masks) > 0:
mask = masks[0]
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
final_mask = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
# Draw result with point marker
plt.figure(figsize=(10, 10))
plt.imshow(pil_image)
if final_mask is not None:
plt.imshow(final_mask, alpha=transparency, cmap=colormap)
# Draw point marker
plt.scatter([point_x], [point_y], c='red', s=200, marker='+', linewidths=3)
plt.scatter([point_x], [point_y], c='red', s=100, marker='o', facecolors='none', linewidths=2)
plt.axis('off')
plt.title(f"Point Prompt Segmentation at ({point_x}, {point_y})", fontsize=12)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
return output_path, f"✅ Point-based segmentation at ({point_x}, {point_y})"
except Exception as e:
print(f"Error in point prompt processing: {e}")
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}"
def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, colormap='spring', transparency=0.5):
"""Process image with a bounding box prompt for segmentation."""
if model is None or processor is None:
return None, "❌ Error: Model not loaded."
if image_file is None:
return None, "⚠️ Please upload a medical image file."
try:
# Load image
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.open(file_path).convert('RGB')
img_array = np.array(pil_image)
h, w = img_array.shape[:2]
# Ensure box coordinates are valid
x1, x2 = min(x1, x2), max(x1, x2)
y1, y2 = min(y1, y2), max(y1, y2)
x1, y1 = max(0, int(x1)), max(0, int(y1))
x2, y2 = min(w, int(x2)), min(h, int(y2))
prompt_text = "segment region in bounding box"
# Process with SAM 3 - using helper function
# Lower thresholds for medical images to ensure detections are not filtered out
results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
final_mask = None
if results and 'masks' in results and results['masks'] is not None:
masks = results['masks']
# Combine all masks
mask_arrays = []
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
# Resize to image size
mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
mask_arrays.append(mask_resized)
if len(mask_arrays) > 0:
combined = np.any(mask_arrays, axis=0)
# Create box mask and intersect
box_mask = np.zeros((h, w), dtype=bool)
box_mask[y1:y2, x1:x2] = True
final_mask = combined & box_mask
# Draw result with box
plt.figure(figsize=(10, 10))
plt.imshow(pil_image)
if final_mask is not None:
plt.imshow(final_mask, alpha=transparency, cmap=colormap)
# Draw bounding box
rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=3, edgecolor='red', facecolor='none')
plt.gca().add_patch(rect)
plt.axis('off')
plt.title(f"Box Prompt Segmentation [{x1}, {y1}, {x2}, {y2}]", fontsize=12)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
return output_path, f"✅ Box-based segmentation at [{x1}, {y1}, {x2}, {y2}]"
except Exception as e:
print(f"Error in box prompt processing: {e}")
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}"
def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks=3):
"""Process image and return multiple mask candidates with confidence scores."""
if model is None or processor is None:
return [], "❌ Error: Model not loaded.", ""
if image_file is None:
return [], "⚠️ Please upload a medical image file.", ""
try:
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.open(file_path).convert('RGB')
if not prompt_text or not prompt_text.strip():
prompt_text = "brain"
# Process with SAM 3 - using helper function
# Lower thresholds for medical images to ensure detections are not filtered out
sam_results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
results = []
mask_info = []
if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
masks = sam_results['masks'] # List of mask tensors
scores = sam_results.get('scores', []) # List of scores
num_available = len(masks)
num_to_show = min(num_masks, num_available)
colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma']
for i in range(num_to_show):
mask = masks[i]
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
# Convert to boolean
if mask_np.dtype != bool:
mask_np = mask_np > 0.5
score = scores[i].item() if i < len(scores) and isinstance(scores[i], torch.Tensor) else (scores[i] if i < len(scores) else 0.5)
# Create visualization
plt.figure(figsize=(8, 8))
plt.imshow(pil_image)
plt.imshow(mask_np, alpha=0.5, cmap=colormaps[i % len(colormaps)])
plt.axis('off')
plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
results.append(output_path)
mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask_np):,} pixels")
status = f"✅ Generated {len(results)} mask candidate(s)"
info = "\n".join(mask_info) if mask_info else "No mask information available"
return results, status, info
except Exception as e:
print(f"Error in multi-mask processing: {e}")
import traceback
traceback.print_exc()
return [], f"❌ Error: {str(e)}", ""
def export_to_nifti(image_file, mask, output_name="segmentation"):
"""Export segmentation mask to NIFTI format.
Returns:
str: Path to the exported NIFTI file, or None if export failed
"""
if not NIBABEL_AVAILABLE:
return None, "⚠️ NIFTI export not available - nibabel not installed"
if mask is None or not isinstance(mask, np.ndarray):
return None, "⚠️ No valid mask to export"
try:
# Convert mask to appropriate format
mask_data = mask.astype(np.float32)
# Create NIFTI image
# Use identity affine (1mm isotropic)
affine = np.eye(4)
# Try to get spacing from DICOM if available
if image_file:
file_path = image_file if isinstance(image_file, str) else str(image_file)
if file_path.lower().endswith('.dcm'):
try:
ds = pydicom.dcmread(file_path, stop_before_pixels=True)
pixel_spacing = getattr(ds, 'PixelSpacing', [1.0, 1.0])
slice_thickness = getattr(ds, 'SliceThickness', 1.0)
affine[0, 0] = float(pixel_spacing[0])
affine[1, 1] = float(pixel_spacing[1])
affine[2, 2] = float(slice_thickness)
except:
pass
nifti_img = nib.Nifti1Image(mask_data, affine)
# Save to temp file
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')
output_path = output_file.name
output_file.close()
nib.save(nifti_img, output_path)
return output_path, f"✅ Exported to NIFTI: {output_path}"
except Exception as e:
print(f"Error exporting to NIFTI: {e}")
return None, f"❌ Export failed: {str(e)}"
def save_annotation(image_file, mask, prompt_text, modality, stats=None):
"""Save annotation to a JSON file for later loading."""
if mask is None:
return None, "⚠️ No annotation to save"
try:
annotation = {
"timestamp": datetime.now().isoformat(),
"image_file": os.path.basename(image_file) if image_file else "unknown",
"prompt": prompt_text,
"modality": modality,
"mask_shape": list(mask.shape),
"mask_sum": int(np.sum(mask)),
"mask_base64": None, # We'll store as binary in a separate file
"statistics": stats if stats else {}
}
# Save mask as numpy file
mask_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npz')
mask_path = mask_file.name
mask_file.close()
np.savez_compressed(mask_path, mask=mask)
# Save annotation JSON
json_file = tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w')
json_path = json_file.name
annotation["mask_file"] = mask_path
json.dump(annotation, json_file, indent=2)
json_file.close()
# Create ZIP with both files
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
zf.write(json_path, 'annotation.json')
zf.write(mask_path, 'mask.npz')
zip_buffer.seek(0)
zip_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
zip_path = zip_file.name
zip_file.write(zip_buffer.read())
zip_file.close()
return zip_path, f"✅ Annotation saved: {os.path.basename(zip_path)}"
except Exception as e:
print(f"Error saving annotation: {e}")
return None, f"❌ Save failed: {str(e)}"
def load_annotation(annotation_file):
"""Load a previously saved annotation."""
if annotation_file is None:
return None, None, "⚠️ No file selected"
try:
file_path = annotation_file if isinstance(annotation_file, str) else str(annotation_file)
if file_path.endswith('.zip'):
# Extract ZIP
with zipfile.ZipFile(file_path, 'r') as zf:
# Read annotation JSON
with zf.open('annotation.json') as f:
annotation = json.load(f)
# Extract mask file
mask_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.npz')
mask_temp.write(zf.read('mask.npz'))
mask_temp.close()
mask_data = np.load(mask_temp.name)
mask = mask_data['mask']
info = f"📋 **Loaded Annotation**\n"
info += f"Image: {annotation.get('image_file', 'unknown')}\n"
info += f"Prompt: {annotation.get('prompt', 'N/A')}\n"
info += f"Modality: {annotation.get('modality', 'N/A')}\n"
info += f"Saved: {annotation.get('timestamp', 'N/A')}\n"
info += f"Mask size: {annotation.get('mask_sum', 0):,} pixels"
return mask, annotation, info
else:
return None, None, "⚠️ Invalid file format. Please upload a .zip annotation file."
except Exception as e:
print(f"Error loading annotation: {e}")
return None, None, f"❌ Load failed: {str(e)}"
def visualize_loaded_annotation(image_file, annotation_file, colormap='spring', transparency=0.5):
"""Visualize a loaded annotation on the original image."""
mask, annotation, info = load_annotation(annotation_file)
if mask is None:
return None, info
if image_file is None:
return None, "⚠️ Please upload the original image to visualize"
try:
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.open(file_path).convert('RGB')
# Resize mask if needed
w, h = pil_image.size
if mask.shape != (h, w):
mask = np.array(Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))) > 127
# Visualize
plt.figure(figsize=(10, 10))
plt.imshow(pil_image)
plt.imshow(mask, alpha=transparency, cmap=colormap)
plt.axis('off')
plt.title(f"Loaded Annotation: {annotation.get('prompt', 'N/A')}", fontsize=12)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
plt.close()
return output_path, info
except Exception as e:
print(f"Error visualizing annotation: {e}")
return None, f"❌ Visualization failed: {str(e)}"
# Store last mask for export/save operations
last_processed_mask = {"mask": None, "image_file": None, "prompt": None, "modality": None}
def process_and_store_mask(image_file, prompt_text, modality, window_type):
"""Process image and store mask for export/save operations."""
result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
if result and mask is not None:
last_processed_mask["mask"] = mask
last_processed_mask["image_file"] = image_file
last_processed_mask["prompt"] = prompt_text
last_processed_mask["modality"] = modality
# Calculate stats
stats = calculate_roi_statistics(image_file, mask, modality)
stats_text = format_roi_statistics(stats)
return result, "✅ Segmentation complete! Ready for export.", stats_text
else:
return result, "❌ Processing failed.", ""
def export_last_mask_nifti():
"""Export the last processed mask to NIFTI."""
if last_processed_mask["mask"] is None:
return None, "⚠️ No mask to export. Process an image first."
return export_to_nifti(
last_processed_mask["image_file"],
last_processed_mask["mask"]
)
# ============================================================================
# SAM-MEDICAL-IMAGING UTILITIES (Inspired by amine0110/SAM-Medical-Imaging)
# Automatic Mask Generator, Advanced Transforms, Grid-based Segmentation
# ============================================================================
class ResizeLongestSide:
"""
Resizes images to the longest side target length while maintaining aspect ratio.
Inspired by SAM-Medical-Imaging transforms.py
"""
def __init__(self, target_length: int = 1024):
self.target_length = target_length
def apply_image(self, image: np.ndarray) -> np.ndarray:
"""Resize image maintaining aspect ratio."""
h, w = image.shape[:2]
scale = self.target_length / max(h, w)
new_h, new_w = int(h * scale), int(w * scale)
pil_image = Image.fromarray(image)
pil_image = pil_image.resize((new_w, new_h), Image.BILINEAR)
return np.array(pil_image)
def apply_coords(self, coords: np.ndarray, original_size: tuple) -> np.ndarray:
"""Resize coordinates to match resized image."""
old_h, old_w = original_size
scale = self.target_length / max(old_h, old_w)
return coords * scale
def apply_boxes(self, boxes: np.ndarray, original_size: tuple) -> np.ndarray:
"""Resize bounding boxes to match resized image."""
boxes = boxes.copy()
boxes[..., :2] = self.apply_coords(boxes[..., :2], original_size)
boxes[..., 2:] = self.apply_coords(boxes[..., 2:], original_size)
return boxes
def generate_grid_points(image_size: tuple, points_per_side: int = 32) -> np.ndarray:
"""
Generate a grid of points for automatic mask generation.
Inspired by SAM AMG (Automatic Mask Generator).
Args:
image_size: (height, width) of the image
points_per_side: Number of points per side of the grid
Returns:
Array of (x, y) point coordinates
"""
h, w = image_size
# Generate evenly spaced points
x_coords = np.linspace(0, w - 1, points_per_side)
y_coords = np.linspace(0, h - 1, points_per_side)
# Create grid
xx, yy = np.meshgrid(x_coords, y_coords)
points = np.stack([xx.flatten(), yy.flatten()], axis=1)
return points
def automatic_mask_generator(image_file, modality, window_type,
points_per_side=16, min_mask_area=100,
colormap='tab20', progress=gr.Progress()):
"""
Automatic Mask Generator (AMG) - Generate masks for entire image without prompts.
Uses a grid of points to query the model and generates masks automatically.
Inspired by SAM-Medical-Imaging's amg.py
"""
if model is None or processor is None:
return None, "❌ Error: Model not loaded.", ""
if image_file is None:
return None, "⚠️ Please upload a medical image file.", ""
try:
progress(0.1, desc="Loading image...")
# Load and preprocess image
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
# Apply windowing
if modality == "CT":
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else:
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
pil_image = Image.fromarray(img_uint8).convert('RGB')
else:
pil_image = Image.open(file_path).convert('RGB')
img_array = np.array(pil_image)
h, w = img_array.shape[:2]
progress(0.2, desc="Generating grid points...")
# Generate grid of points
grid_points = generate_grid_points((h, w), points_per_side)
total_points = len(grid_points)
# Collect all masks
all_masks = []
all_scores = []
progress(0.3, desc=f"Processing {total_points} points...")
# Use different prompts to generate diverse masks
prompts = ["anatomical structure", "region", "tissue", "organ", "segment"]
for prompt_idx, prompt in enumerate(prompts):
progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")
try:
# Process with SAM 3 - using helper function
# Lower thresholds for medical images to ensure detections are not filtered out
sam_results = run_sam3_inference(pil_image, prompt, threshold=0.1, mask_threshold=0.0)
if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
masks = sam_results['masks'] # List of mask tensors
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
# Convert to boolean
if mask_np.dtype != bool:
mask_np = mask_np > 0.5
# Filter by minimum area
mask_area = np.sum(mask_np)
if mask_area >= min_mask_area:
# Resize mask to image size
mask_resized = np.array(
Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))
) > 127
all_masks.append(mask_resized)
all_scores.append(mask_area)
except Exception as e:
print(f"Error with prompt '{prompt}': {e}")
continue
progress(0.85, desc="Combining masks...")
if not all_masks:
return None, "⚠️ No masks generated. Try different parameters.", ""
# Remove duplicate/overlapping masks using Non-Maximum Suppression
unique_masks = []
for mask in all_masks:
is_duplicate = False
for existing in unique_masks:
# Check IoU
intersection = np.logical_and(mask, existing).sum()
union = np.logical_or(mask, existing).sum()
iou = intersection / union if union > 0 else 0
if iou > 0.8: # High overlap threshold
is_duplicate = True
break
if not is_duplicate:
unique_masks.append(mask)
progress(0.9, desc="Creating visualization...")
# Create visualization with all masks
plt.figure(figsize=(12, 12))
plt.imshow(pil_image)
# Create colored overlay for all masks
cmap = plt.cm.get_cmap(colormap)
for idx, mask in enumerate(unique_masks):
color = cmap(idx / max(len(unique_masks), 1))[:3]
colored_mask = np.zeros((*mask.shape, 4))
colored_mask[mask] = (*color, 0.4)
plt.imshow(colored_mask)
plt.axis('off')
plt.title(f"Automatic Mask Generation: {len(unique_masks)} regions detected", fontsize=14)
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
plt.close()
progress(1.0, desc="Complete!")
# Generate info text
info_text = f"📊 **AMG Results**\n\n"
info_text += f"**Total Regions:** {len(unique_masks)}\n"
info_text += f"**Grid Points:** {points_per_side}x{points_per_side} = {points_per_side**2}\n"
info_text += f"**Min Area Filter:** {min_mask_area} pixels\n\n"
for idx, mask in enumerate(unique_masks[:10]): # Show first 10
area = np.sum(mask)
percentage = area / (h * w) * 100
info_text += f"Region {idx+1}: {area:,} pixels ({percentage:.2f}%)\n"
if len(unique_masks) > 10:
info_text += f"\n... and {len(unique_masks) - 10} more regions"
return output_path, f"✅ AMG Complete! Found {len(unique_masks)} regions.", info_text
except Exception as e:
print(f"Error in AMG: {e}")
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}", ""
def process_with_advanced_transforms(image_file, prompt_text, modality, window_type,
target_size=1024, apply_clahe=False,
clahe_clip=2.0, colormap='spring', transparency=0.5):
"""
Process image with advanced transforms from SAM-Medical-Imaging.
- ResizeLongestSide: Maintains aspect ratio
- CLAHE: Contrast Limited Adaptive Histogram Equalization (optional)
"""
if model is None or processor is None:
return None, "❌ Error: Model not loaded."
if image_file is None:
return None, "⚠️ Please upload a medical image file."
try:
# Load image
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
# Apply windowing
if modality == "CT":
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else:
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
else:
img = Image.open(file_path)
if img.mode != 'L':
img = img.convert('L')
img_uint8 = np.array(img)
original_size = img_uint8.shape[:2]
# Apply CLAHE if requested
if apply_clahe:
try:
from scipy.ndimage import uniform_filter
# Simple CLAHE-like enhancement using local contrast
local_mean = uniform_filter(img_uint8.astype(float), size=50)
local_std = np.sqrt(uniform_filter(img_uint8.astype(float)**2, size=50) - local_mean**2 + 1e-8)
# Enhance contrast
enhanced = (img_uint8 - local_mean) / (local_std + 1e-8) * clahe_clip
enhanced = np.clip(enhanced * 30 + 128, 0, 255).astype(np.uint8)
img_uint8 = enhanced
except Exception as e:
print(f"CLAHE enhancement failed: {e}")
# Apply ResizeLongestSide transform
transform = ResizeLongestSide(target_size)
if len(img_uint8.shape) == 2:
img_uint8_3ch = np.stack([img_uint8] * 3, axis=-1)
else:
img_uint8_3ch = img_uint8
img_resized = transform.apply_image(img_uint8_3ch)
pil_image = Image.fromarray(img_resized)
if not prompt_text or not prompt_text.strip():
prompt_text = "brain"
# Process with SAM 3 - using helper function
# Lower thresholds for medical images to ensure detections are not filtered out
results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
final_mask = None
if results and 'masks' in results and results['masks'] is not None:
masks = results['masks']
# Combine all masks
mask_arrays = []
for mask in masks:
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy()
else:
mask_np = np.array(mask)
mask_arrays.append(mask_np)
if len(mask_arrays) > 0:
final_mask = np.any(mask_arrays, axis=0)
# Visualize
plt.figure(figsize=(12, 6))
# Original
plt.subplot(1, 2, 1)
plt.imshow(Image.fromarray(img_uint8_3ch[:, :, 0] if len(img_uint8_3ch.shape) == 3 else img_uint8_3ch), cmap='gray')
plt.title(f"Original ({original_size[0]}x{original_size[1]})", fontsize=10)
plt.axis('off')
# Processed with mask
plt.subplot(1, 2, 2)
plt.imshow(pil_image)
if final_mask is not None:
# Resize mask to match processed image
mask_h, mask_w = final_mask.shape
img_h, img_w = img_resized.shape[:2]
if (mask_h, mask_w) != (img_h, img_w):
final_mask = np.array(
Image.fromarray(final_mask.astype(np.uint8) * 255).resize((img_w, img_h))
) > 127
plt.imshow(final_mask, alpha=transparency, cmap=colormap)
plt.title(f"Transformed ({target_size}px) + Segmentation", fontsize=10)
plt.axis('off')
plt.tight_layout()
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', dpi=100)
plt.close()
status = f"✅ Processed with ResizeLongestSide({target_size})"
if apply_clahe:
status += f" + CLAHE(clip={clahe_clip})"
return output_path, status
except Exception as e:
print(f"Error in advanced transforms: {e}")
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}"
def edge_based_segmentation(image_file, modality, window_type,
edge_threshold=50, dilation_size=3,
colormap='spring', transparency=0.5):
"""
Edge-based automatic segmentation using Sobel/Canny-like detection.
Useful for finding boundaries in medical images.
"""
if image_file is None:
return None, "⚠️ Please upload a medical image file."
try:
# Load image
file_path = image_file if isinstance(image_file, str) else str(image_file)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.dcm':
ds = pydicom.dcmread(file_path)
img_array = ds.pixel_array.astype(np.float32)
slope = getattr(ds, 'RescaleSlope', 1)
intercept = getattr(ds, 'RescaleIntercept', 0)
img_array = img_array * slope + intercept
if modality == "CT":
if window_type == "Brain (Grey Matter)":
level, width = 40, 80
elif window_type == "Bone (Skull)":
level, width = 500, 2000
else:
level, width = 40, 400
img_min = level - (width / 2)
img_max = level + (width / 2)
else:
img_min = np.percentile(img_array, 1)
img_max = np.percentile(img_array, 99)
img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
img_uint8 = (img_norm * 255).astype(np.uint8)
else:
img = Image.open(file_path).convert('L')
img_uint8 = np.array(img)
# Compute Sobel edges
from scipy.ndimage import sobel, binary_dilation, binary_fill_holes
# Sobel edge detection
dx = sobel(img_uint8.astype(float), axis=1)
dy = sobel(img_uint8.astype(float), axis=0)
edges = np.hypot(dx, dy)
# Threshold edges
edge_mask = edges > edge_threshold
# Dilate edges to connect nearby components
if dilation_size > 0:
struct = np.ones((dilation_size, dilation_size))
edge_mask = binary_dilation(edge_mask, structure=struct)
# Fill holes
filled_mask = binary_fill_holes(edge_mask)
# Label connected components
labeled, num_features = ndimage.label(filled_mask)
# Create RGB image for display
pil_image = Image.fromarray(img_uint8).convert('RGB')
# Visualize
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(img_uint8, cmap='gray')
plt.title("Original Image", fontsize=10)
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(edges, cmap='hot')
plt.title(f"Edge Detection (threshold={edge_threshold})", fontsize=10)
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(pil_image)
plt.imshow(filled_mask, alpha=transparency, cmap=colormap)
plt.title(f"Segmentation ({num_features} regions)", fontsize=10)
plt.axis('off')
plt.tight_layout()
output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
output_path = output_file.name
output_file.close()
plt.savefig(output_path, bbox_inches='tight', dpi=100)
plt.close()
return output_path, f"✅ Edge-based segmentation complete! Found {num_features} regions."
except Exception as e:
print(f"Error in edge segmentation: {e}")
import traceback
traceback.print_exc()
return None, f"❌ Error: {str(e)}"
def save_last_annotation():
"""Save the last processed annotation."""
if last_processed_mask["mask"] is None:
return None, "⚠️ No annotation to save. Process an image first."
stats = calculate_roi_statistics(
last_processed_mask["image_file"],
last_processed_mask["mask"],
last_processed_mask["modality"]
)
return save_annotation(
last_processed_mask["image_file"],
last_processed_mask["mask"],
last_processed_mask["prompt"],
last_processed_mask["modality"],
stats
)
# Create Gradio Interface
demo_file_path = demo_dicom_path if demo_file_available and os.path.exists(demo_dicom_path) else None
def load_demo_file():
"""Load the demo DICOM file."""
if demo_file_path and os.path.exists(demo_file_path):
return demo_file_path, f"✅ Demo file loaded: {demo_file_path}\nReady to segment!"
else:
return None, "⚠️ Demo file not found. Please upload a medical image file (DICOM, PNG, or JPG)."
def process_with_status(image_file, prompt_text, modality, window_type):
"""Wrapper function to update status during processing."""
if model is None or processor is None:
return None, "❌ Error: Model not loaded."
if image_file is None:
return None, "⚠️ Please upload a medical image file (DICOM, PNG, or JPG) or load the demo file."
result = process_medical_image(image_file, prompt_text, modality, window_type)
if result is None:
return None, "❌ Processing failed. Check console for error details."
else:
return result, "✅ Segmentation complete!"
def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type):
"""Process image and compare with ground truth segmentation mask."""
if model is None or processor is None:
return None, None, 0.0, 0.0, "❌ Error: Model not loaded."
if image_file is None:
return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file."
if gt_mask_file is None:
return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file."
# Process image and get mask
result, pred_mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
if result is None or pred_mask is None:
return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details."
# Compare with ground truth
comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file)
if comparison_path:
status = f"✅ Segmentation complete!\nDice Score: {dice_score:.3f}\nIoU Score: {iou_score:.3f}"
return result, comparison_path, dice_score, iou_score, status
else:
return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed."
def process_sequence(image_files, prompt_text, modality, window_type):
"""Process multiple images from the same subject and return gallery of results."""
if model is None or processor is None:
return [], "❌ Error: Model not loaded."
if not image_files:
return [], "⚠️ Please upload medical image files (DICOM, PNG, or JPG)."
# Handle single file or list of files
if isinstance(image_files, str):
image_files = [image_files]
results = []
status_messages = []
for idx, image_file in enumerate(image_files):
if image_file is None:
continue
status_msg = f"Processing image {idx + 1}/{len(image_files)}..."
status_messages.append(status_msg)
result = process_medical_image(image_file, prompt_text, modality, window_type)
if result:
results.append(result)
status_messages.append(f"✅ Image {idx + 1} segmented successfully")
else:
status_messages.append(f"❌ Failed to process image {idx + 1}")
if results:
status = f"✅ Processed {len(results)}/{len(image_files)} images successfully!\n" + "\n".join(status_messages)
return results, status
else:
return [], "❌ No images were processed successfully. Check console for error details."
# Store processed results for interactive viewer
processed_results_cache = {}
def extract_subject_id(file_path):
"""Extract subject/patient ID from file path.
Common patterns:
- Folder name: /subject_001/image.png -> subject_001
- Filename prefix: subject_001_slice_01.png -> subject_001
- Patient ID in filename: patient_123_slice_5.dcm -> patient_123
- Study UID in DICOM: extract from DICOM metadata
Returns:
tuple: (subject_id, confidence_level, source)
confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename pattern), 'low' (fallback)
source: 'dicom_patientid', 'dicom_study', 'folder', 'filename', 'fallback'
"""
import re
file_path = str(file_path)
filename = os.path.basename(file_path)
dir_path = os.path.dirname(file_path)
# HIGHEST CONFIDENCE: DICOM metadata (most reliable)
if file_path.lower().endswith('.dcm'):
try:
ds = pydicom.dcmread(file_path, stop_before_pixels=True)
patient_id = getattr(ds, 'PatientID', None)
if patient_id and patient_id.strip():
return f"patient_{patient_id}", 'high', 'dicom_patientid'
study_uid = getattr(ds, 'StudyInstanceUID', None)
if study_uid:
# Use full study UID as identifier (unique per study)
return f"study_{study_uid}", 'high', 'dicom_study'
except:
pass
# MEDIUM CONFIDENCE: Folder name (common in medical datasets)
folder_name = os.path.basename(dir_path.rstrip('/'))
if folder_name and folder_name not in ['', '.', '..']:
# Check if folder name looks like a subject ID
if re.match(r'(subject|patient|sub|pat|case|id)[_-]?\d+', folder_name, re.I):
return folder_name, 'medium', 'folder'
# MEDIUM CONFIDENCE: Filename pattern
patterns = [
(r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', 'medium'), # subject_001, patient_123
(r'([A-Z]{2,}\d+)', 'medium'), # BR001, MR123, etc.
]
for pattern, confidence in patterns:
match = re.search(pattern, filename, re.I)
if match:
if len(match.groups()) > 1:
return f"{match.group(1)}_{match.group(2)}", confidence, 'filename'
else:
return match.group(1), confidence, 'filename'
# LOW CONFIDENCE: Numeric pattern (could be slice number, not patient ID)
numeric_match = re.search(r'(\d{3,})', filename)
if numeric_match:
return numeric_match.group(1), 'low', 'filename_numeric'
# LOWEST CONFIDENCE: Fallback to filename
base_name = os.path.splitext(filename)[0]
if len(base_name) > 0:
return base_name, 'low', 'fallback'
return "unknown", 'low', 'unknown'
def group_images_by_subject(image_files):
"""Group image files by subject/patient ID.
Returns:
dict: {subject_id: {'files': [...], 'confidence': 'high/medium/low', 'sources': set(...)}}
"""
if not image_files:
return {}
if isinstance(image_files, str):
image_files = [image_files]
# Filter out None files
image_files = [f for f in image_files if f is not None]
# Group by subject ID and track confidence
subject_groups = {}
for file_path in image_files:
subject_id, confidence, source = extract_subject_id(file_path)
if subject_id not in subject_groups:
subject_groups[subject_id] = {
'files': [],
'confidence': confidence,
'sources': set([source])
}
subject_groups[subject_id]['files'].append(file_path)
subject_groups[subject_id]['sources'].add(source)
# Upgrade confidence if we find high-confidence source
if confidence == 'high' or (confidence == 'medium' and subject_groups[subject_id]['confidence'] == 'low'):
subject_groups[subject_id]['confidence'] = confidence
# Sort files within each group (by filename)
for subject_id in subject_groups:
subject_groups[subject_id]['files'].sort()
subject_groups[subject_id]['sources'] = list(subject_groups[subject_id]['sources'])
return subject_groups
def detect_subjects(image_files):
"""Detect and return subject groups from uploaded files."""
if not image_files:
return gr.Dropdown(choices=[], value=None), "No files uploaded"
subject_groups = group_images_by_subject(image_files)
if not subject_groups:
return gr.Dropdown(choices=[], value=None), "No subjects detected"
choices = []
status_msg = f"✅ Detected {len(subject_groups)} subject(s):\n\n"
for subject_id, info in sorted(subject_groups.items()):
num_files = len(info['files'])
confidence = info['confidence']
sources = ', '.join(info['sources'])
# Add confidence indicator
if confidence == 'high':
confidence_icon = "✅"
confidence_text = "HIGH (DICOM metadata - very reliable)"
elif confidence == 'medium':
confidence_icon = "⚠️"
confidence_text = "MEDIUM (folder/filename pattern - likely same patient)"
else:
confidence_icon = "⚠️⚠️"
confidence_text = "LOW (filename-based - verify manually)"
choices.append(f"{subject_id} ({num_files} slices)")
status_msg += f"{confidence_icon} **{subject_id}**: {num_files} slices\n"
status_msg += f" Confidence: {confidence_text}\n"
status_msg += f" Source: {sources}\n\n"
# Add warning for low confidence
low_confidence_count = sum(1 for info in subject_groups.values() if info['confidence'] == 'low')
if low_confidence_count > 0:
status_msg += f"⚠️ **Warning**: {low_confidence_count} subject(s) detected with LOW confidence.\n"
status_msg += "Please verify these are actually the same patient before proceeding.\n"
return gr.Dropdown(choices=choices, value=choices[0] if choices else None), status_msg
def process_slices_for_viewer(image_files, selected_subject, prompt_text, modality, window_type):
"""Process all slices for selected subject and cache results for interactive viewing."""
if model is None or processor is None:
return None, 0, "❌ Error: Model not loaded.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
if not image_files:
return None, 0, "⚠️ Please upload medical image files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
# Group by subject
subject_groups = group_images_by_subject(image_files)
if not subject_groups:
return None, 0, "⚠️ Could not detect subjects in uploaded files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
# Extract subject ID from selection (format: "subject_id (N slices)")
if selected_subject:
subject_id = selected_subject.split(" (")[0]
else:
# Use first subject if none selected
subject_id = list(subject_groups.keys())[0]
if subject_id not in subject_groups:
return None, 0, f"⚠️ Subject '{subject_id}' not found.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
# Get files for selected subject
subject_info = subject_groups[subject_id]
subject_files = subject_info['files']
confidence = subject_info['confidence']
# Add confidence warning
confidence_warning = ""
if confidence == 'low':
confidence_warning = "\n⚠️ LOW CONFIDENCE: These files may not be from the same patient. Please verify!"
elif confidence == 'medium':
confidence_warning = "\n⚠️ MEDIUM CONFIDENCE: Likely same patient, but verify if critical."
results = []
status_messages = []
for idx, image_file in enumerate(subject_files):
status_msg = f"Processing slice {idx + 1}/{len(subject_files)}..."
status_messages.append(status_msg)
result = process_medical_image(image_file, prompt_text, modality, window_type)
if result:
results.append(result)
status_messages.append(f"✅ Slice {idx + 1} processed")
else:
status_messages.append(f"❌ Failed to process slice {idx + 1}")
if results:
# Cache results with a unique key including subject ID
cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"
processed_results_cache[cache_key] = results
max_slices = len(results) - 1
status = f"✅ Processed {len(results)}/{len(subject_files)} slices for {subject_id}!\nUse slider or buttons to navigate.{confidence_warning}"
slice_info = f"Slice 1/{len(results)} ({subject_id})"
# Update subject dropdown choices
choices = []
for sid, info in sorted(subject_groups.items()):
marker = "✓" if sid == subject_id else ""
num_files = len(info['files'])
choices.append(f"{marker} {sid} ({num_files} slices)")
return results[0], max_slices, status, slice_info, gr.Dropdown(choices=choices, value=choices[0] if choices else None), f"Viewing: {subject_id}"
else:
return None, 0, "❌ No slices were processed successfully.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
def navigate_slice(slice_idx, image_files, selected_subject, prompt_text, modality, window_type):
"""Navigate to a specific slice in the sequence."""
if not image_files:
return None, "No slices loaded"
# Group by subject and get selected subject's files
subject_groups = group_images_by_subject(image_files)
if selected_subject:
subject_id = selected_subject.split(" (")[0]
else:
subject_id = list(subject_groups.keys())[0] if subject_groups else None
if not subject_id or subject_id not in subject_groups:
return None, "No slices loaded"
subject_info = subject_groups[subject_id]
subject_files = subject_info['files']
slice_idx = int(slice_idx)
cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"
if cache_key in processed_results_cache:
results = processed_results_cache[cache_key]
if 0 <= slice_idx < len(results):
slice_info = f"Slice {slice_idx + 1}/{len(results)} ({subject_id})"
return results[slice_idx], slice_info
# If not cached, process on the fly (fallback)
if 0 <= slice_idx < len(subject_files):
result = process_medical_image(subject_files[slice_idx], prompt_text, modality, window_type)
if result:
slice_info = f"Slice {slice_idx + 1}/{len(subject_files)} ({subject_id})"
return result, slice_info
return None, f"Invalid slice index: {slice_idx}"
with gr.Blocks() as demo:
gr.Markdown("# 🏥 NeuroSAM 3: Medical Image Segmentation")
demo_info = ""
if demo_file_path:
demo_info = f"\n\n**📁 Demo File Available:** A sample DICOM file is ready: `{demo_file_path}`\nClick 'Load Demo File' button below to use it!"
gr.Markdown(f"""
Upload a medical image (DICOM .dcm, PNG, or JPG) and type what you want to find (e.g., 'brain', 'tumor', 'skull').
{demo_info}
**Instructions:**
1. Upload a medical image file:
- DICOM (.dcm) files for CT or MRI scans
- PNG/JPG files (e.g., from Kaggle datasets like brain MRI images)
2. Enter a text prompt describing what to segment
3. Select the imaging modality (CT or MRI)
4. Choose the windowing strategy (for CT images)
5. Click "Segment Structure" to process
**Supported Formats:**
- DICOM (.dcm) - Standard medical imaging format
- PNG/JPG - Standard image formats (works with Kaggle brain MRI datasets)
""")
with gr.Tabs():
with gr.Tab("Single Image"):
with gr.Row():
with gr.Column():
file_input = gr.File(
label="Upload Medical Image (DICOM .dcm, PNG, JPG)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath",
value=demo_file_path
)
load_demo_btn = gr.Button(
"📁 Load Demo File",
variant="secondary",
size="sm",
visible=bool(demo_file_path)
)
text_input = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull, eyes",
info="Describe what anatomical structure or region you want to segment"
)
with gr.Row():
modality_dropdown = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI",
info="Select the imaging modality"
)
window_dropdown = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)",
info="CT windowing preset (ignored for MRI)"
)
submit_btn = gr.Button("Segment Structure", variant="primary", size="lg")
with gr.Column():
image_output = gr.Image(
label="Segmentation Result",
type="filepath"
)
gr.Markdown("### Status")
status_text = gr.Textbox(
label="Processing Status",
value="Ready. Upload a medical image file (DICOM, PNG, or JPG) to begin.",
interactive=False
)
with gr.Tab("Interactive Slice Viewer"):
gr.Markdown("**Scroll through multiple slices/images from the same subject interactively**")
gr.Markdown("""
**📋 Subject Detection:** The app automatically detects subject/patient IDs from:
- Folder names (e.g., `subject_001/`, `patient_123/`)
- Filenames (e.g., `subject_001_slice_01.png`, `patient_123.dcm`)
- DICOM metadata (PatientID, StudyInstanceUID)
**💡 Tip:** Upload images organized by subject folders for best results!
""")
with gr.Row():
with gr.Column():
files_input = gr.File(
label="Upload Multiple Images/Slices (Select multiple files)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
file_count="multiple",
type="filepath"
)
subject_dropdown = gr.Dropdown(
label="Select Subject/Patient",
choices=[],
value=None,
interactive=True,
info="Select which subject's slices to view (auto-detected from filenames/folders)"
)
text_input_batch = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull, eyes",
info="Describe what anatomical structure or region you want to segment"
)
with gr.Row():
modality_dropdown_batch = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI",
info="Select the imaging modality"
)
window_dropdown_batch = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)",
info="CT windowing preset (ignored for MRI)"
)
detect_subjects_btn = gr.Button("🔍 Detect Subjects", variant="secondary", size="sm")
submit_batch_btn = gr.Button("Process All Slices", variant="primary", size="lg")
gr.Markdown("---")
gr.Markdown("### 🎛️ Slice Navigator")
slice_slider = gr.Slider(
minimum=0,
maximum=0,
step=1,
value=0,
label="Slice Number",
info="Use slider or arrow keys to navigate through slices",
interactive=False
)
with gr.Row():
prev_btn = gr.Button("⬆️ Previous Slice", size="sm")
next_btn = gr.Button("⬇️ Next Slice", size="sm")
auto_play_btn = gr.Button("▶️ Auto-play", size="sm")
with gr.Column():
current_slice_output = gr.Image(
label="Current Slice Segmentation",
type="filepath",
height=600
)
gr.Markdown("### Slice Info")
slice_info_text = gr.Textbox(
label="Current Slice",
value="No slices loaded",
interactive=False
)
subject_info_text = gr.Textbox(
label="Subject Info",
value="",
interactive=False,
visible=False
)
gr.Markdown("### Status")
status_batch_text = gr.Textbox(
label="Processing Status",
value="Ready. Upload multiple medical image files to process a sequence.",
interactive=False,
lines=4
)
with gr.Tab("Gallery View"):
gr.Markdown("**View all segmentations in a gallery grid**")
with gr.Row():
with gr.Column():
files_input_gallery = gr.File(
label="Upload Multiple Images (Select multiple files)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
file_count="multiple",
type="filepath"
)
text_input_gallery = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull, eyes"
)
with gr.Row():
modality_dropdown_gallery = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI"
)
window_dropdown_gallery = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)"
)
submit_gallery_btn = gr.Button("Process & Show Gallery", variant="primary", size="lg")
with gr.Column():
gallery_output = gr.Gallery(
label="Segmentation Gallery",
show_label=True,
elem_id="gallery",
columns=2,
rows=2,
height="auto"
)
status_gallery_text = gr.Textbox(
label="Status",
value="Ready. Upload multiple images to view in gallery.",
interactive=False
)
with gr.Tab("Compare with Ground Truth"):
gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**")
with gr.Row():
with gr.Column():
file_input_gt = gr.File(
label="Upload Medical Image (DICOM .dcm, PNG, JPG)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
gt_mask_input = gr.File(
label="Upload Ground Truth Mask (PNG, JPG)",
file_types=[".png", ".jpg", ".jpeg"],
type="filepath"
)
text_input_gt = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull",
info="Describe what anatomical structure or region you want to segment"
)
with gr.Row():
modality_dropdown_gt = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI",
info="Select the imaging modality"
)
window_dropdown_gt = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)",
info="CT windowing preset (ignored for MRI)"
)
submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg")
with gr.Column():
image_output_gt = gr.Image(
label="SAM 3 Segmentation",
type="filepath"
)
comparison_output = gr.Image(
label="Comparison: SAM 3 vs Ground Truth",
type="filepath"
)
gr.Markdown("### Metrics")
dice_score_text = gr.Textbox(
label="Dice Score",
value="--",
interactive=False
)
iou_score_text = gr.Textbox(
label="IoU Score",
value="--",
interactive=False
)
gr.Markdown("### Status")
status_gt_text = gr.Textbox(
label="Processing Status",
value="Ready. Upload image and ground truth mask to compare.",
interactive=False,
lines=3
)
# NEW ENHANCED TABS
with gr.Tab("✨ Enhanced Single Image"):
gr.Markdown("**Enhanced version with image adjustments, visualization options, and progress tracking**")
with gr.Row():
with gr.Column():
file_input_enh = gr.File(
label="Upload Medical Image (DICOM .dcm, PNG, JPG)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath",
value=demo_file_path
)
text_input_enh = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull, eyes",
info="Describe what anatomical structure or region you want to segment"
)
with gr.Row():
modality_enh = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI",
info="Select the imaging modality"
)
window_enh = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)",
info="CT windowing preset (ignored for MRI)"
)
# NEW: Image adjustment controls
gr.Markdown("### 🎨 Image Adjustments")
brightness_slider = gr.Slider(
0.5, 2.0, value=1.0, step=0.1,
label="Brightness",
info="Adjust image brightness (0.5 = darker, 2.0 = brighter)"
)
contrast_slider = gr.Slider(
0.5, 2.0, value=1.0, step=0.1,
label="Contrast",
info="Adjust image contrast (0.5 = lower contrast, 2.0 = higher contrast)"
)
# NEW: Visualization options
gr.Markdown("### 🎭 Visualization Options")
colormap_dropdown = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"],
label="Mask Colormap",
value="spring",
info="Color scheme for segmentation overlay"
)
transparency_slider = gr.Slider(
0.0, 1.0, value=0.5, step=0.1,
label="Mask Transparency",
info="Transparency of segmentation overlay (0.0 = opaque, 1.0 = transparent)"
)
submit_enh_btn = gr.Button("Segment with Progress", variant="primary", size="lg")
with gr.Column():
# Progress indicator
progress_text = gr.Textbox(
label="Progress",
value="Ready. Upload a medical image file to begin.",
interactive=False
)
image_output_enh = gr.Image(
label="Segmentation Result",
type="filepath"
)
# NEW: Download button
download_output = gr.File(
label="Download Result",
visible=True
)
# NEW: Metrics display
gr.Markdown("### 📊 Metrics")
metrics_text = gr.Textbox(
label="Segmentation Info",
value="",
lines=3,
interactive=False
)
with gr.Tab("✨ Enhanced Batch Processing"):
gr.Markdown("**Enhanced batch processing with ZIP download and progress tracking**")
with gr.Row():
with gr.Column():
files_input_enh_batch = gr.File(
label="Upload Multiple Images for Batch Processing",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
file_count="multiple",
type="filepath"
)
text_input_enh_batch = gr.Textbox(
label="Text Prompt",
value="brain",
placeholder="e.g. brain, tumor, skull, eyes"
)
with gr.Row():
modality_enh_batch = gr.Dropdown(
["CT", "MRI"],
label="Modality",
value="MRI"
)
window_enh_batch = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing Strategy (CT only)",
value="Brain (Grey Matter)"
)
# Image adjustments
gr.Markdown("### 🎨 Image Adjustments")
brightness_slider_batch = gr.Slider(
0.5, 2.0, value=1.0, step=0.1,
label="Brightness"
)
contrast_slider_batch = gr.Slider(
0.5, 2.0, value=1.0, step=0.1,
label="Contrast"
)
# Visualization options
gr.Markdown("### 🎭 Visualization Options")
colormap_dropdown_batch = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"],
label="Mask Colormap",
value="spring"
)
transparency_slider_batch = gr.Slider(
0.0, 1.0, value=0.5, step=0.1,
label="Mask Transparency"
)
submit_enh_batch_btn = gr.Button("Process Batch with Progress", variant="primary", size="lg")
with gr.Column():
gallery_output_enh = gr.Gallery(
label="Segmentation Gallery",
show_label=True,
elem_id="gallery",
columns=2,
rows=2,
height="auto"
)
# Batch download ZIP
batch_download_output = gr.File(
label="Download All Results (ZIP)",
visible=True
)
status_enh_batch_text = gr.Textbox(
label="Status",
value="Ready. Upload multiple images to process in batch.",
interactive=False,
lines=4
)
# NEW: Point/Box Prompts Tab
with gr.Tab("🎯 Point/Box Prompts"):
gr.Markdown("""
**Interactive Point and Box-based Segmentation**
Use precise point clicks or bounding boxes to guide the segmentation.
- **Point Prompt**: Click on the region you want to segment
- **Box Prompt**: Define a bounding box around the region of interest
""")
with gr.Tabs():
with gr.Tab("Point Prompt"):
with gr.Row():
with gr.Column():
file_input_point = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
gr.Markdown("### Point Coordinates")
with gr.Row():
point_x = gr.Number(label="X coordinate", value=128, precision=0)
point_y = gr.Number(label="Y coordinate", value=128, precision=0)
with gr.Row():
modality_point = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_point = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
with gr.Row():
colormap_point = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma"],
label="Colormap", value="spring"
)
transparency_point = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
submit_point_btn = gr.Button("Segment at Point", variant="primary")
with gr.Column():
output_point = gr.Image(label="Point Segmentation Result", type="filepath")
status_point = gr.Textbox(label="Status", interactive=False)
with gr.Tab("Box Prompt"):
with gr.Row():
with gr.Column():
file_input_box = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
gr.Markdown("### Bounding Box Coordinates")
with gr.Row():
box_x1 = gr.Number(label="X1 (left)", value=50, precision=0)
box_y1 = gr.Number(label="Y1 (top)", value=50, precision=0)
with gr.Row():
box_x2 = gr.Number(label="X2 (right)", value=200, precision=0)
box_y2 = gr.Number(label="Y2 (bottom)", value=200, precision=0)
with gr.Row():
modality_box = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_box = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
with gr.Row():
colormap_box = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma"],
label="Colormap", value="spring"
)
transparency_box = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
submit_box_btn = gr.Button("Segment in Box", variant="primary")
with gr.Column():
output_box = gr.Image(label="Box Segmentation Result", type="filepath")
status_box = gr.Textbox(label="Status", interactive=False)
# NEW: ROI Statistics & Export Tab
with gr.Tab("📊 ROI Statistics & Export"):
gr.Markdown("""
**ROI Statistics and Export Options**
Process an image and get detailed statistics about the segmented region:
- Area (pixels and percentage)
- Intensity statistics (mean, std, min, max)
- Centroid and bounding box
- Export to NIFTI format for medical imaging software
- Save/Load annotations for later use
""")
with gr.Row():
with gr.Column():
file_input_stats = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
text_input_stats = gr.Textbox(
label="Text Prompt", value="brain",
placeholder="e.g. brain, tumor, skull"
)
with gr.Row():
modality_stats = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_stats = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
submit_stats_btn = gr.Button("Process & Calculate Statistics", variant="primary")
gr.Markdown("### Export Options")
with gr.Row():
export_nifti_btn = gr.Button("📥 Export to NIFTI", size="sm")
save_annotation_btn = gr.Button("💾 Save Annotation", size="sm")
with gr.Column():
output_stats = gr.Image(label="Segmentation Result", type="filepath")
status_stats = gr.Textbox(label="Status", interactive=False)
gr.Markdown("### 📊 ROI Statistics")
roi_stats_text = gr.Markdown(value="*Process an image to see statistics*")
nifti_download = gr.File(label="Download NIFTI", visible=True)
annotation_download = gr.File(label="Download Annotation", visible=True)
gr.Markdown("---")
gr.Markdown("### Load Saved Annotation")
with gr.Row():
with gr.Column():
annotation_upload = gr.File(
label="Upload Annotation (.zip)",
file_types=[".zip"],
type="filepath"
)
original_image_upload = gr.File(
label="Upload Original Image (for visualization)",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
load_annotation_btn = gr.Button("Load & Visualize Annotation", variant="secondary")
with gr.Column():
loaded_annotation_output = gr.Image(label="Loaded Annotation", type="filepath")
loaded_annotation_info = gr.Markdown(value="*Upload an annotation file to load*")
# NEW: Multi-Mask Output Tab
with gr.Tab("🎭 Multi-Mask Output"):
gr.Markdown("""
**Generate Multiple Mask Candidates**
SAM can generate multiple segmentation hypotheses with confidence scores.
This is useful when the segmentation is ambiguous or you want to compare alternatives.
""")
with gr.Row():
with gr.Column():
file_input_multi = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
text_input_multi = gr.Textbox(
label="Text Prompt", value="brain",
placeholder="e.g. brain, tumor, skull"
)
with gr.Row():
modality_multi = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_multi = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
num_masks_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Masks")
submit_multi_btn = gr.Button("Generate Multiple Masks", variant="primary")
with gr.Column():
gallery_multi = gr.Gallery(
label="Mask Candidates",
show_label=True,
columns=2,
rows=2,
height="auto"
)
status_multi = gr.Textbox(label="Status", interactive=False)
mask_info_multi = gr.Textbox(label="Mask Information", lines=5, interactive=False)
# NEW: SAM-Medical-Imaging Inspired Tabs
with gr.Tab("🔬 Automatic Mask Generator"):
gr.Markdown("""
**Automatic Mask Generator (AMG)**
Inspired by [SAM-Medical-Imaging](https://github.com/amine0110/SAM-Medical-Imaging) -
automatically segment the entire image without needing specific prompts.
- Uses multiple prompts internally to discover all regions
- Filters and deduplicates overlapping masks
- Great for exploratory analysis of medical images
""")
with gr.Row():
with gr.Column():
file_input_amg = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
with gr.Row():
modality_amg = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_amg = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
gr.Markdown("### AMG Parameters")
points_per_side = gr.Slider(
8, 32, value=16, step=4,
label="Grid Density",
info="Higher = more detailed but slower"
)
min_mask_area = gr.Slider(
50, 1000, value=100, step=50,
label="Minimum Mask Area (pixels)",
info="Filter out small noise regions"
)
colormap_amg = gr.Dropdown(
["tab20", "tab10", "Set1", "Set2", "Paired", "rainbow"],
label="Colormap for Regions",
value="tab20"
)
submit_amg_btn = gr.Button("🔬 Run Automatic Segmentation", variant="primary", size="lg")
with gr.Column():
output_amg = gr.Image(label="AMG Result", type="filepath")
status_amg = gr.Textbox(label="Status", interactive=False)
info_amg = gr.Markdown(value="*Run AMG to see region details*")
with gr.Tab("🔧 Advanced Transforms"):
gr.Markdown("""
**Advanced Image Transforms**
Apply preprocessing transforms from SAM-Medical-Imaging before segmentation:
- **ResizeLongestSide**: Resize maintaining aspect ratio (optimal for SAM)
- **CLAHE-like Enhancement**: Enhance local contrast for better visibility
""")
with gr.Row():
with gr.Column():
file_input_transform = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
text_input_transform = gr.Textbox(
label="Text Prompt", value="brain",
placeholder="e.g. brain, tumor, skull"
)
with gr.Row():
modality_transform = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_transform = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
gr.Markdown("### Transform Settings")
target_size_slider = gr.Slider(
256, 2048, value=1024, step=128,
label="Target Size (ResizeLongestSide)",
info="Resize image's longest side to this value"
)
apply_clahe_checkbox = gr.Checkbox(
label="Apply CLAHE-like Enhancement",
value=False,
info="Enhance local contrast (useful for low-contrast images)"
)
clahe_clip_slider = gr.Slider(
1.0, 4.0, value=2.0, step=0.5,
label="CLAHE Clip Limit",
info="Higher = more contrast enhancement"
)
with gr.Row():
colormap_transform = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma"],
label="Colormap", value="spring"
)
transparency_transform = gr.Slider(
0.0, 1.0, value=0.5, step=0.1,
label="Transparency"
)
submit_transform_btn = gr.Button("🔧 Process with Transforms", variant="primary")
with gr.Column():
output_transform = gr.Image(label="Transform + Segmentation Result", type="filepath")
status_transform = gr.Textbox(label="Status", interactive=False)
with gr.Tab("🌊 Edge-Based Segmentation"):
gr.Markdown("""
**Edge-Based Automatic Segmentation**
Uses Sobel edge detection to find boundaries in medical images.
Works without AI model - useful for comparison or when model fails.
- Detects edges using gradient-based methods
- Fills regions within boundaries
- Identifies distinct anatomical structures
""")
with gr.Row():
with gr.Column():
file_input_edge = gr.File(
label="Upload Medical Image",
file_types=[".dcm", ".png", ".jpg", ".jpeg"],
type="filepath"
)
with gr.Row():
modality_edge = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
window_edge = gr.Dropdown(
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
label="Windowing", value="Brain (Grey Matter)"
)
gr.Markdown("### Edge Detection Parameters")
edge_threshold_slider = gr.Slider(
10, 150, value=50, step=10,
label="Edge Threshold",
info="Lower = more edges detected"
)
dilation_size_slider = gr.Slider(
0, 10, value=3, step=1,
label="Dilation Size",
info="Connect nearby edges (0 = no dilation)"
)
with gr.Row():
colormap_edge = gr.Dropdown(
["spring", "cool", "hot", "viridis", "plasma"],
label="Colormap", value="spring"
)
transparency_edge = gr.Slider(
0.0, 1.0, value=0.5, step=0.1,
label="Transparency"
)
submit_edge_btn = gr.Button("🌊 Run Edge Segmentation", variant="primary")
with gr.Column():
output_edge = gr.Image(label="Edge Segmentation Result", type="filepath")
status_edge = gr.Textbox(label="Status", interactive=False)
# Single image processing
load_demo_btn.click(
fn=load_demo_file,
inputs=[],
outputs=[file_input, status_text]
)
submit_btn.click(
fn=process_with_status,
inputs=[file_input, text_input, modality_dropdown, window_dropdown],
outputs=[image_output, status_text]
)
# Detect subjects when files are uploaded
detect_subjects_btn.click(
fn=detect_subjects,
inputs=[files_input],
outputs=[subject_dropdown, status_batch_text]
)
# Interactive slice viewer
submit_batch_btn.click(
fn=process_slices_for_viewer,
inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
outputs=[current_slice_output, slice_slider, status_batch_text, slice_info_text, subject_dropdown, subject_info_text]
).then(
lambda max_val: gr.Slider(maximum=max(max_val, 1), interactive=True),
inputs=[slice_slider],
outputs=[slice_slider]
)
def update_slice(slice_num, files, selected_subject, prompt, mod, window):
result, info = navigate_slice(int(slice_num), files, selected_subject, prompt, mod, window)
return result, info
slice_slider.change(
fn=update_slice,
inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
outputs=[current_slice_output, slice_info_text]
)
def prev_slice(current, files, selected_subject, prompt, mod, window):
new_val = max(0, current - 1)
result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
return new_val, result, info
def next_slice(current, max_val, files, selected_subject, prompt, mod, window):
new_val = min(max_val, current + 1)
result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
return new_val, result, info
prev_btn.click(
fn=prev_slice,
inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
outputs=[slice_slider, current_slice_output, slice_info_text]
)
next_btn.click(
fn=next_slice,
inputs=[slice_slider, slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
outputs=[slice_slider, current_slice_output, slice_info_text]
)
# Gallery view
submit_gallery_btn.click(
fn=process_sequence,
inputs=[files_input_gallery, text_input_gallery, modality_dropdown_gallery, window_dropdown_gallery],
outputs=[gallery_output, status_gallery_text]
)
# Ground truth comparison
submit_gt_btn.click(
fn=process_with_ground_truth,
inputs=[file_input_gt, gt_mask_input, text_input_gt, modality_dropdown_gt, window_dropdown_gt],
outputs=[image_output_gt, comparison_output, dice_score_text, iou_score_text, status_gt_text]
)
# Enhanced single image processing
def process_enhanced_wrapper(image_file, prompt_text, modality, window_type,
brightness, contrast, colormap, transparency, progress=gr.Progress()):
"""Wrapper to return both image and download file."""
result, status, metrics = process_with_progress(
image_file, prompt_text, modality, window_type,
brightness, contrast, colormap, transparency, progress
)
download_file = result if result else None
return result, status, metrics, download_file
submit_enh_btn.click(
fn=process_enhanced_wrapper,
inputs=[
file_input_enh, text_input_enh, modality_enh, window_enh,
brightness_slider, contrast_slider, colormap_dropdown, transparency_slider
],
outputs=[image_output_enh, progress_text, metrics_text, download_output]
)
# Enhanced batch processing
submit_enh_batch_btn.click(
fn=process_batch_enhanced,
inputs=[
files_input_enh_batch, text_input_enh_batch, modality_enh_batch, window_enh_batch,
brightness_slider_batch, contrast_slider_batch, colormap_dropdown_batch, transparency_slider_batch
],
outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
)
# Point prompt processing
submit_point_btn.click(
fn=process_with_point_prompt,
inputs=[file_input_point, point_x, point_y, modality_point, window_point, colormap_point, transparency_point],
outputs=[output_point, status_point]
)
# Box prompt processing
submit_box_btn.click(
fn=process_with_box_prompt,
inputs=[file_input_box, box_x1, box_y1, box_x2, box_y2, modality_box, window_box, colormap_box, transparency_box],
outputs=[output_box, status_box]
)
# ROI Statistics processing
submit_stats_btn.click(
fn=process_and_store_mask,
inputs=[file_input_stats, text_input_stats, modality_stats, window_stats],
outputs=[output_stats, status_stats, roi_stats_text]
)
# NIFTI Export
export_nifti_btn.click(
fn=export_last_mask_nifti,
inputs=[],
outputs=[nifti_download, status_stats]
)
# Save Annotation
save_annotation_btn.click(
fn=save_last_annotation,
inputs=[],
outputs=[annotation_download, status_stats]
)
# Load Annotation
load_annotation_btn.click(
fn=visualize_loaded_annotation,
inputs=[original_image_upload, annotation_upload],
outputs=[loaded_annotation_output, loaded_annotation_info]
)
# Multi-Mask processing
submit_multi_btn.click(
fn=process_multi_mask,
inputs=[file_input_multi, text_input_multi, modality_multi, window_multi, num_masks_slider],
outputs=[gallery_multi, status_multi, mask_info_multi]
)
# Auto-play functionality for slice viewer
def auto_play_slices(files, selected_subject, prompt, mod, window):
"""Auto-play through slices with a short delay."""
if not files:
yield None, "No slices loaded", 0
return
subject_groups = group_images_by_subject(files)
if selected_subject:
subject_id = selected_subject.split(" (")[0]
else:
subject_id = list(subject_groups.keys())[0] if subject_groups else None
if not subject_id or subject_id not in subject_groups:
yield None, "No slices loaded", 0
return
subject_files = subject_groups[subject_id]['files']
cache_key = f"{subject_id}_{len(subject_files)}_{prompt}_{mod}"
if cache_key not in processed_results_cache:
yield None, "Please process slices first", 0
return
results = processed_results_cache[cache_key]
for idx in range(len(results)):
slice_info = f"Slice {idx + 1}/{len(results)} ({subject_id}) - Auto-playing..."
yield results[idx], slice_info, idx
time.sleep(0.5) # 500ms delay between slices
auto_play_btn.click(
fn=auto_play_slices,
inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
outputs=[current_slice_output, slice_info_text, slice_slider]
)
# SAM-Medical-Imaging Inspired Features
# Automatic Mask Generator
submit_amg_btn.click(
fn=automatic_mask_generator,
inputs=[file_input_amg, modality_amg, window_amg, points_per_side, min_mask_area, colormap_amg],
outputs=[output_amg, status_amg, info_amg]
)
# Advanced Transforms
submit_transform_btn.click(
fn=process_with_advanced_transforms,
inputs=[file_input_transform, text_input_transform, modality_transform, window_transform,
target_size_slider, apply_clahe_checkbox, clahe_clip_slider, colormap_transform, transparency_transform],
outputs=[output_transform, status_transform]
)
# Edge-Based Segmentation
submit_edge_btn.click(
fn=edge_based_segmentation,
inputs=[file_input_edge, modality_edge, window_edge, edge_threshold_slider, dilation_size_slider,
colormap_edge, transparency_edge],
outputs=[output_edge, status_edge]
)
if __name__ == "__main__":
# Verify model is loaded before launching
if model is None or processor is None:
print("⚠️ WARNING: SAM 3 model failed to load!")
print("⚠️ The app will start but segmentation features will not work.")
print("⚠️ Please check:")
print(" 1. HF_TOKEN environment variable is set correctly")
print(" 2. transformers>=4.45.0 is installed")
print(" 3. Sufficient memory/GPU available")
else:
print("✅ SAM 3 model ready - app starting...")
demo.launch(server_name="0.0.0.0", server_port=7860)