""" NeuroSAM 3: Medical Image Segmentation App A Gradio app for segmenting medical images (CT/MRI) using SAM 3 """ import os import tempfile import zipfile import io import json import time from datetime import datetime import gradio as gr import spaces import torch import pydicom import numpy as np from PIL import Image, ImageEnhance, ImageDraw try: from transformers import Sam3Processor, Sam3Model SAM3_AVAILABLE = True except ImportError: print("⚠️ Warning: Sam3Processor/Sam3Model not found in transformers.") print("⚠️ SAM3 requires transformers from GitHub main branch.") print("⚠️ Install with: pip install git+https://github.com/huggingface/transformers.git") SAM3_AVAILABLE = False # Create dummy classes to prevent import errors Sam3Processor = None Sam3Model = None import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from scipy import ndimage from huggingface_hub import login # Try to import nibabel for NIFTI support (optional) try: import nibabel as nib NIBABEL_AVAILABLE = True except ImportError: NIBABEL_AVAILABLE = False print("⚠️ nibabel not available - NIFTI export disabled") # Hugging Face Token (must be set as HF_TOKEN environment variable in Space settings) hf_token = os.getenv("HF_TOKEN") if not hf_token: print("⚠️ WARNING: HF_TOKEN environment variable not set!") print("⚠️ Some features may not work. 
Please set HF_TOKEN in Space settings.") hf_token = None # Allow app to start, but model loading will fail gracefully else: # Login to Hugging Face Hub (only if token is provided) try: login(token=hf_token, add_to_git_credential=False) except Exception as e: print(f"⚠️ Could not login to HF Hub (non-critical): {e}") # Load SAM 3 Model print("🧠 Loading SAM 3 Model...") # IMPORTANT: For HF Spaces with Stateless GPU, load model on CPU in main process # Model will be moved to GPU inside @spaces.GPU decorated functions model = None processor = None if not SAM3_AVAILABLE: print("❌ SAM 3 classes not available in transformers library.") print("❌ Install with: pip install git+https://github.com/huggingface/transformers.git") print("⚠️ App will start but segmentation features will be disabled.") else: # SAM 3 model identifier - matching official implementation SAM_MODEL_ID = "facebook/sam3" if hf_token is None: print("⚠️ Cannot load model: HF_TOKEN not set") model = None processor = None else: try: # Load model on CPU to avoid CUDA initialization in main process (for HF Spaces Stateless GPU) # Model will be moved to GPU inside @spaces.GPU decorated functions model = Sam3Model.from_pretrained( SAM_MODEL_ID, torch_dtype=torch.float32, # Load as float32 on CPU token=hf_token ) processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=hf_token) model.eval() print(f"✅ SAM 3 Model Loaded Successfully on CPU! ({SAM_MODEL_ID})") print("💡 Model will be moved to GPU when inference is called") except Exception as e: print(f"⚠️ Failed to load SAM 3 model: {e}") print("Ensure you have:") print(" 1. transformers from GitHub main branch for SAM 3 support") print(" Install with: pip install git+https://github.com/huggingface/transformers.git") print(" 2. Valid Hugging Face token with access to SAM 3") print(" 3. 
Sufficient memory for the model") print("⚠️ App will start but segmentation features will be disabled until model loads.") # Don't raise - allow app to start and show error in UI model = None processor = None @spaces.GPU(duration=60) def run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0): """ Run SAM 3 inference - optimized for medical imaging. Args: pil_image: PIL Image to segment prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull") threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images). Lower values (0.0-0.3) are more permissive and better for subtle features. Higher values (0.5-1.0) require high confidence, may miss detections. mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images). Lower values preserve more detail. Higher values create sharper masks. Medical images often benefit from 0.0 to capture subtle boundaries. Returns: results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed Note: Default thresholds (0.1, 0.0) are optimized for medical imaging where features may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5) may be more appropriate. """ if model is None or processor is None: print("❌ Model not loaded - please check HF_TOKEN and model availability") raise ValueError("SAM 3 model not loaded. Please check that HF_TOKEN is set correctly and the model is accessible.") def to_serializable(obj): """ Convert all tensors to numpy arrays or Python primitives for safe serialization. This ensures NO PyTorch tensors (CPU or CUDA) are in the return value. 
""" if isinstance(obj, torch.Tensor): # Convert to numpy array (works for both CPU and CUDA tensors) result = obj.cpu().numpy() print(f"🔄 Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}") return result elif isinstance(obj, dict): return {k: to_serializable(v) for k, v in obj.items()} elif isinstance(obj, list): return [to_serializable(item) for item in obj] elif isinstance(obj, tuple): return tuple(to_serializable(item) for item in obj) elif isinstance(obj, (int, float, str, bool, type(None))): return obj elif hasattr(obj, 'item'): # numpy scalar return obj.item() else: # For unknown types, try to convert to string representation print(f"⚠️ Unknown type encountered: {type(obj)}, converting to string") return str(obj) try: # Determine device and move model to GPU if available (CUDA initialization happens here, inside @spaces.GPU) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🔧 Using device: {device}") # Move model to device and set appropriate dtype # Note: For nn.Module, .to() modifies in-place and returns self # IMPORTANT: @spaces.GPU ensures sequential execution - requests are queued and processed # one at a time, so there's NO concurrent access to the model. This makes in-place # modification safe despite model being a global variable. 
dtype = torch.float16 if device == "cuda" else torch.float32 model.to(device=device, dtype=dtype) print(f"✅ Model moved to {device} with dtype {dtype}") # Prepare inputs - matching official implementation inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device) # Convert float32 inputs to model dtype (float16 for GPU) - matching official implementation for key in inputs: if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32: inputs[key] = inputs[key].to(model.dtype) with torch.no_grad(): outputs = model(**inputs) print(f"🧠 Inference complete, processing results...") # Post-process using processor method - matching official implementation results = processor.post_process_instance_segmentation( outputs, threshold=threshold, mask_threshold=mask_threshold, target_sizes=inputs.get("original_sizes").tolist() if "original_sizes" in inputs else [pil_image.size[::-1]] )[0] # Get first batch result print(f"📊 Results type: {type(results)}") if isinstance(results, dict): print(f"📊 Results keys: {results.keys()}") for key, value in results.items(): print(f" - {key}: type={type(value)}") if isinstance(value, torch.Tensor): print(f" tensor device={value.device}, shape={value.shape}, dtype={value.dtype}") elif isinstance(value, list) and len(value) > 0: print(f" list length={len(value)}, first item type={type(value[0])}") if isinstance(value[0], torch.Tensor): print(f" first tensor device={value[0].device}") # CRITICAL: Convert ALL tensors to numpy arrays before returning # This ensures NO PyTorch tensors (CPU or CUDA) cross the process boundary # Numpy arrays are safely serializable without triggering CUDA init print(f"🔄 Converting all tensors to numpy arrays...") results = to_serializable(results) print(f"✅ All tensors converted to serializable format") # Move model back to CPU to free GPU memory (important for Spaces) model.to("cpu") print(f"✅ Model moved back to CPU") return results except Exception as e: 
print(f"❌ Error during SAM 3 inference: {e}") import traceback traceback.print_exc() # Make sure to move model back to CPU even on error if model is not None: try: model.to("cpu") except RuntimeError as cleanup_error: print(f"⚠️ Could not move model back to CPU: {cleanup_error}") return None # Create Sample DICOM File for Demo demo_dicom_path = "demo_brain_mri.dcm" demo_file_available = False try: from pydicom.data import get_testdata_file test_file = get_testdata_file("MR_small.dcm") if test_file and os.path.exists(test_file): import shutil shutil.copy(test_file, demo_dicom_path) demo_file_available = True print(f"✅ Demo file ready: {demo_dicom_path}") except: try: # Create synthetic DICOM file from pydicom.dataset import FileDataset, FileMetaDataset from pydicom.uid import generate_uid from datetime import datetime synthetic_image = np.random.randint(0, 255, (256, 256), dtype=np.uint16) center_x, center_y = 128, 128 y, x = np.ogrid[:256, :256] mask = (x - center_x)**2 + (y - center_y)**2 <= 100**2 synthetic_image[mask] = np.clip(synthetic_image[mask] + 50, 0, 255) file_meta = FileMetaDataset() file_meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.4' file_meta.MediaStorageSOPInstanceUID = generate_uid() file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1' ds = FileDataset(demo_dicom_path, {}, file_meta=file_meta, preamble=b"\x00" * 128) ds.PatientName = "Demo^Patient" ds.PatientID = "DEMO001" ds.Modality = "MR" ds.Rows = 256 ds.Columns = 256 ds.BitsAllocated = 16 ds.BitsStored = 16 ds.HighBit = 15 ds.SamplesPerPixel = 1 ds.PixelRepresentation = 0 ds.PhotometricInterpretation = "MONOCHROME2" ds.PixelSpacing = [1.0, 1.0] ds.RescaleIntercept = "0" ds.RescaleSlope = "1" ds.PixelData = synthetic_image.tobytes() ds.save_as(demo_dicom_path, write_like_original=False) demo_file_available = True print(f"✅ Synthetic demo file created: {demo_dicom_path}") except Exception as e: print(f"⚠️ Could not create demo file: {e}") def compare_with_ground_truth(pred_mask, 
gt_mask_path): """Compare SAM 3 prediction with ground truth mask and return comparison metrics.""" try: gt_mask = Image.open(gt_mask_path) gt_array = np.array(gt_mask.convert('L')) > 127 # Binarize # Resize prediction mask to match ground truth if needed if pred_mask.shape != gt_array.shape: from PIL import Image as PILImage pred_pil = PILImage.fromarray((pred_mask * 255).astype(np.uint8)) pred_pil = pred_pil.resize(gt_mask.size, PILImage.NEAREST) pred_mask = np.array(pred_pil) > 127 # Calculate metrics intersection = np.logical_and(pred_mask, gt_array).sum() union = np.logical_or(pred_mask, gt_array).sum() dice_score = (2.0 * intersection) / (pred_mask.sum() + gt_array.sum()) if (pred_mask.sum() + gt_array.sum()) > 0 else 0.0 iou_score = intersection / union if union > 0 else 0.0 # Create comparison visualization fig, axes = plt.subplots(1, 3, figsize=(15, 5)) axes[0].imshow(pred_mask, cmap='spring') axes[0].set_title('SAM 3 Prediction') axes[0].axis('off') axes[1].imshow(gt_array, cmap='cool') axes[1].set_title('Ground Truth') axes[1].axis('off') # Overlay comparison comparison = np.zeros((*pred_mask.shape, 3)) comparison[pred_mask & gt_array] = [0, 1, 0] # Green: True Positive comparison[pred_mask & ~gt_array] = [1, 0, 0] # Red: False Positive comparison[~pred_mask & gt_array] = [0, 0, 1] # Blue: False Negative axes[2].imshow(comparison) axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}') axes[2].axis('off') plt.tight_layout() output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', dpi=100) plt.close() return output_path, dice_score, iou_score except Exception as e: print(f"⚠️ Error comparing with ground truth: {e}") return None, 0.0, 0.0 def process_medical_image(image_file, prompt_text, modality, window_type, return_mask=False): """Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3. 
Args: image_file: Path to image file prompt_text: Text prompt for segmentation modality: CT or MRI window_type: Windowing strategy return_mask: If True, also return the binary mask array Returns: Path to output image, and optionally the mask array """ if model is None or processor is None: print("❌ Error: Model not loaded.") return None if image_file is None: return None if not prompt_text or not prompt_text.strip(): prompt_text = "brain" try: file_path = image_file if isinstance(image_file, str) else str(image_file) if not os.path.exists(file_path): print(f"❌ Error: File not found at {file_path}") return None # Detect file type file_ext = os.path.splitext(file_path)[1].lower() is_dicom = file_ext == '.dcm' if is_dicom: # Process DICOM file ds = pydicom.dcmread(file_path) if not hasattr(ds, 'pixel_array'): print("❌ Error: DICOM file does not contain pixel data.") return None raw = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_hu = raw * slope + intercept # Apply Windowing if modality == "CT": if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI img_min = np.percentile(img_hu, 1) img_max = np.percentile(img_hu, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_hu) img_max = np.max(img_hu) img_range = img_max - img_min if img_range <= 0: return None img_windowed = (img_hu - img_min) / img_range img_windowed = np.clip(img_windowed, 0, 1) img_uint8 = (img_windowed * 255).astype(np.uint8) if len(img_uint8.shape) == 2: pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.fromarray(img_uint8) else: # Process standard image file (PNG, JPG, etc.) 
pil_image = Image.open(file_path) # Convert to RGB if needed if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # Convert to numpy for normalization img_array = np.array(pil_image) # Handle grayscale images if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize image (percentile-based for MRI-like processing) img_float = img_array.astype(np.float32) if modality == "CT": # For CT-like processing, use windowing if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI - use percentile normalization img_min = np.percentile(img_float, 1) img_max = np.percentile(img_float, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_float) img_max = np.max(img_float) img_range = img_max - img_min if img_range <= 0: return None img_normalized = (img_float - img_min) / img_range img_normalized = np.clip(img_normalized, 0, 1) img_uint8 = (img_normalized * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8.astype(np.uint8)) # Run SAM 3 Inference - using helper function matching official implementation # Lower thresholds for medical images to ensure detections are not filtered out results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) if results is None: return None # Draw Masks on Image - matching official implementation format plt.figure(figsize=(10, 10)) plt.imshow(pil_image) final_mask = None if 'masks' in results and results['masks'] is not None: masks = results['masks'] # List of mask tensors from post_process_instance_segmentation scores = results.get('scores', []) if len(masks) > 0: # Combine all masks into one (or use first mask) # Convert tensors to numpy and combine mask_arrays = [] for mask in masks: if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) 
mask_arrays.append(mask_np) # Combine all masks if len(mask_arrays) > 0: final_mask = np.any(mask_arrays, axis=0) plt.imshow(final_mask, alpha=0.5, cmap='spring') else: print("⚠️ Warning: No valid masks found.") else: print("⚠️ Warning: No masks in results.") else: print("⚠️ Warning: No masks in results.") plt.axis('off') plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() if return_mask: return output_path, final_mask return output_path except pydicom.errors.InvalidDicomError as e: print(f"❌ Error: Invalid DICOM file format. {e}") return None except Exception as e: print(f"❌ Error processing image: {e}") import traceback traceback.print_exc() return None def process_medical_image_enhanced(image_file, prompt_text, modality, window_type, brightness=1.0, contrast=1.0, colormap='spring', transparency=0.5, return_mask=False): """Enhanced version of process_medical_image with image adjustments and visualization options. 
Args: image_file: Path to image file prompt_text: Text prompt for segmentation modality: CT or MRI window_type: Windowing strategy brightness: Brightness multiplier (0.5-2.0) contrast: Contrast multiplier (0.5-2.0) colormap: Matplotlib colormap name transparency: Mask overlay transparency (0.0-1.0) return_mask: If True, also return the binary mask array Returns: Path to output image, and optionally the mask array """ if model is None or processor is None: print("❌ Error: Model not loaded.") return None if image_file is None: return None if not prompt_text or not prompt_text.strip(): prompt_text = "brain" try: file_path = image_file if isinstance(image_file, str) else str(image_file) if not os.path.exists(file_path): print(f"❌ Error: File not found at {file_path}") return None # Detect file type file_ext = os.path.splitext(file_path)[1].lower() is_dicom = file_ext == '.dcm' if is_dicom: # Process DICOM file ds = pydicom.dcmread(file_path) if not hasattr(ds, 'pixel_array'): print("❌ Error: DICOM file does not contain pixel data.") return None raw = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_hu = raw * slope + intercept # Apply Windowing if modality == "CT": if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI img_min = np.percentile(img_hu, 1) img_max = np.percentile(img_hu, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_hu) img_max = np.max(img_hu) img_range = img_max - img_min if img_range <= 0: return None img_windowed = (img_hu - img_min) / img_range img_windowed = np.clip(img_windowed, 0, 1) img_uint8 = (img_windowed * 255).astype(np.uint8) if len(img_uint8.shape) == 2: pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.fromarray(img_uint8) else: # Process 
standard image file (PNG, JPG, etc.) pil_image = Image.open(file_path) # Convert to RGB if needed if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # Convert to numpy for normalization img_array = np.array(pil_image) # Handle grayscale images if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize image (percentile-based for MRI-like processing) img_float = img_array.astype(np.float32) if modality == "CT": # For CT-like processing, use windowing if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI - use percentile normalization img_min = np.percentile(img_float, 1) img_max = np.percentile(img_float, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_float) img_max = np.max(img_float) img_range = img_max - img_min if img_range <= 0: return None img_normalized = (img_float - img_min) / img_range img_normalized = np.clip(img_normalized, 0, 1) img_uint8 = (img_normalized * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8.astype(np.uint8)) # Apply brightness and contrast adjustments enhancer = ImageEnhance.Brightness(pil_image) pil_image = enhancer.enhance(brightness) enhancer = ImageEnhance.Contrast(pil_image) pil_image = enhancer.enhance(contrast) # Run SAM 3 Inference - using helper function matching official implementation # Lower thresholds for medical images to ensure detections are not filtered out results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) if results is None: return None # Draw Masks on Image with enhanced visualization - matching official implementation format plt.figure(figsize=(10, 10)) plt.imshow(pil_image) final_mask = None if 'masks' in results and results['masks'] is not None: masks = results['masks'] # List of mask tensors from 
post_process_instance_segmentation scores = results.get('scores', []) if len(masks) > 0: # Combine all masks into one mask_arrays = [] for mask in masks: if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) mask_arrays.append(mask_np) # Combine all masks if len(mask_arrays) > 0: final_mask = np.any(mask_arrays, axis=0) plt.imshow(final_mask, alpha=transparency, cmap=colormap) else: print("⚠️ Warning: No valid masks found.") else: print("⚠️ Warning: No masks in results.") else: print("⚠️ Warning: No masks in results.") plt.axis('off') plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() if return_mask: return output_path, final_mask return output_path except pydicom.errors.InvalidDicomError as e: print(f"❌ Error: Invalid DICOM file format. {e}") return None except Exception as e: print(f"❌ Error processing image: {e}") import traceback traceback.print_exc() return None def process_with_progress(image_file, prompt_text, modality, window_type, brightness=1.0, contrast=1.0, colormap='spring', transparency=0.5, progress=gr.Progress()): """Process with progress indicator.""" if model is None or processor is None: return None, "❌ Error: Model not loaded.", "" if image_file is None: return None, "⚠️ Please upload a medical image file.", "" progress(0, desc="Starting...") progress(0.1, desc="Reading image...") result = process_medical_image_enhanced( image_file, prompt_text, modality, window_type, brightness, contrast, colormap, transparency ) progress(0.8, desc="Processing segmentation...") if result is None: progress(1.0, desc="Failed!") return None, "❌ Processing failed. 
Check console for error details.", "" progress(1.0, desc="Complete!") # Calculate metrics metrics = f"Prompt: {prompt_text}\nModality: {modality}\nProcessed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" return result, "✅ Segmentation complete!", metrics def create_downloadable_result(output_path, prompt_text): """Create a downloadable file with metadata.""" return output_path # Gradio File component handles downloads automatically def process_batch_enhanced(image_files, prompt_text, modality, window_type, brightness=1.0, contrast=1.0, colormap='spring', transparency=0.5, progress=gr.Progress()): """Process multiple images with enhanced features and create ZIP download.""" if model is None or processor is None: return [], None, "❌ Error: Model not loaded." if not image_files: return [], None, "⚠️ Please upload medical image files." # Handle single file or list of files if isinstance(image_files, str): image_files = [image_files] results = [] total = len(image_files) for idx, image_file in enumerate(image_files): if image_file is None: continue progress((idx + 1) / total, desc=f"Processing image {idx + 1}/{total}...") result = process_medical_image_enhanced( image_file, prompt_text, modality, window_type, brightness, contrast, colormap, transparency ) if result: results.append(result) if not results: return [], None, "❌ No images were processed successfully." # Create ZIP file with all results zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: for idx, result_path in enumerate(results): if os.path.exists(result_path): filename = f"segmentation_{idx + 1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" zip_file.write(result_path, filename) zip_buffer.seek(0) zip_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') zip_path.write(zip_buffer.read()) zip_path.close() status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download." 
return results, zip_path.name, status # ============================================================================ # ENHANCED FEATURES - Auto-play, Point/Box Prompts, ROI Stats, NIFTI Export # ============================================================================ # Global state for auto-play auto_play_state = {"running": False, "current_idx": 0} def calculate_roi_statistics(image_file, mask, modality): """Calculate ROI statistics from the segmented region. Returns: dict: Statistics including area, mean intensity, std, min, max, centroid """ if mask is None or not isinstance(mask, np.ndarray): return { "error": "No valid mask available", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } try: # Load original image for intensity statistics file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept else: img = Image.open(file_path) if img.mode == 'RGB': img = img.convert('L') # Convert to grayscale for intensity stats img_array = np.array(img).astype(np.float32) # Resize mask if needed if mask.shape != img_array.shape: from scipy.ndimage import zoom zoom_factors = (img_array.shape[0] / mask.shape[0], img_array.shape[1] / mask.shape[1]) mask = zoom(mask.astype(float), zoom_factors, order=0) > 0.5 # Calculate statistics mask_bool = mask.astype(bool) total_pixels = mask.size roi_pixels = np.sum(mask_bool) if roi_pixels == 0: return { "error": "No pixels in ROI", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } roi_intensities = 
img_array[mask_bool] # Calculate centroid labeled_mask, num_features = ndimage.label(mask_bool) centroid = ndimage.center_of_mass(mask_bool) # Calculate bounding box rows = np.any(mask_bool, axis=1) cols = np.any(mask_bool, axis=0) rmin, rmax = np.where(rows)[0][[0, -1]] cmin, cmax = np.where(cols)[0][[0, -1]] stats = { "area_pixels": int(roi_pixels), "area_percentage": float(roi_pixels / total_pixels * 100), "mean_intensity": float(np.mean(roi_intensities)), "std_intensity": float(np.std(roi_intensities)), "min_intensity": float(np.min(roi_intensities)), "max_intensity": float(np.max(roi_intensities)), "centroid": (float(centroid[1]), float(centroid[0])), # (x, y) "bounding_box": (int(cmin), int(rmin), int(cmax), int(rmax)), # (x1, y1, x2, y2) "num_components": num_features } # Add HU statistics for CT if modality == "CT": stats["mean_hu"] = stats["mean_intensity"] stats["std_hu"] = stats["std_intensity"] return stats except Exception as e: print(f"Error calculating ROI statistics: {e}") return {"error": str(e)} def format_roi_statistics(stats): """Format ROI statistics as a readable string.""" if "error" in stats and stats.get("area_pixels", 0) == 0: return f"⚠️ {stats.get('error', 'No statistics available')}" text = "📊 **ROI Statistics**\n\n" text += f"**Area:** {stats['area_pixels']:,} pixels ({stats['area_percentage']:.2f}%)\n" text += f"**Intensity:** {stats['mean_intensity']:.2f} ± {stats['std_intensity']:.2f}\n" text += f"**Range:** [{stats['min_intensity']:.2f}, {stats['max_intensity']:.2f}]\n" text += f"**Centroid:** ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})\n" text += f"**Bounding Box:** {stats['bounding_box']}\n" text += f"**Components:** {stats.get('num_components', 1)}" if "mean_hu" in stats: text += f"\n\n**CT (Hounsfield Units):**\n" text += f"Mean HU: {stats['mean_hu']:.1f} ± {stats['std_hu']:.1f}" return text def process_with_roi_stats(image_file, prompt_text, modality, window_type): """Process image and return both segmentation and 
ROI statistics.""" if model is None or processor is None: return None, "❌ Error: Model not loaded.", "" if image_file is None: return None, "⚠️ Please upload a medical image file.", "" result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True) if result is None: return None, "❌ Processing failed.", "" # Calculate ROI statistics stats = calculate_roi_statistics(image_file, mask, modality) stats_text = format_roi_statistics(stats) return result, "✅ Segmentation complete!", stats_text def process_with_point_prompt(image_file, point_x, point_y, modality, window_type, colormap='spring', transparency=0.5): """Process image with a point prompt for segmentation. Note: This simulates point-based prompting by using the point location as a seed for region-based segmentation. """ if model is None or processor is None: return None, "❌ Error: Model not loaded." if image_file is None: return None, "⚠️ Please upload a medical image file." try: # Load image file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept # Normalize img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.open(file_path).convert('RGB') img_array = np.array(pil_image) h, w = img_array.shape[:2] # Clamp point coordinates point_x = max(0, min(int(point_x), w - 1)) point_y = max(0, min(int(point_y), h - 1)) # Create a prompt based on the point location prompt_text = f"segment region at point" # Process with SAM 3 - using helper function # Lower thresholds for 
def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, colormap='spring', transparency=0.5):
    """Segment a medical image restricted to a user-drawn bounding box.

    The model is queried with a generic text prompt; every returned mask is
    merged and the union is clipped to the box region.

    Returns:
        tuple: (overlay_png_path_or_None, status_message)
    """
    # Guard clauses: a loaded model and an input image are both required.
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."
    if image_file is None:
        return None, "⚠️ Please upload a medical image file."
    try:
        src = image_file if isinstance(image_file, str) else str(image_file)
        ext = os.path.splitext(src)[1].lower()
        if ext == '.dcm':
            # DICOM: apply rescale slope/intercept, then a robust 1–99
            # percentile normalization before converting to 8-bit RGB.
            ds = pydicom.dcmread(src)
            pixels = ds.pixel_array.astype(np.float32)
            pixels = pixels * getattr(ds, 'RescaleSlope', 1) + getattr(ds, 'RescaleIntercept', 0)
            lo = np.percentile(pixels, 1)
            hi = np.percentile(pixels, 99)
            scaled = np.clip((pixels - lo) / (hi - lo + 1e-8), 0, 1)
            pil_image = Image.fromarray((scaled * 255).astype(np.uint8)).convert('RGB')
        else:
            pil_image = Image.open(src).convert('RGB')
        h, w = np.array(pil_image).shape[:2]
        # Normalize the box: sort the corners, then clamp to image bounds.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(w, int(x2)), min(h, int(y2))
        prompt_text = "segment region in bounding box"
        # Low thresholds so medical-image detections are not filtered out.
        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            resized_masks = []
            for m in results['masks']:
                arr = m.cpu().numpy() if isinstance(m, torch.Tensor) else np.array(m)
                # Resize each mask to the displayed image resolution.
                resized_masks.append(
                    np.array(Image.fromarray((arr * 255).astype(np.uint8)).resize((w, h))) > 127
                )
            if resized_masks:
                # Union of all candidate masks, intersected with the box.
                box_region = np.zeros((h, w), dtype=bool)
                box_region[y1:y2, x1:x2] = True
                final_mask = np.any(resized_masks, axis=0) & box_region
        # Render the overlay plus the box outline.
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)
        if final_mask is not None:
            plt.imshow(final_mask, alpha=transparency, cmap=colormap)
        plt.gca().add_patch(
            Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=3, edgecolor='red', facecolor='none')
        )
        plt.axis('off')
        plt.title(f"Box Prompt Segmentation [{x1}, {y1}, {x2}, {y2}]", fontsize=12)
        handle = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = handle.name
        handle.close()
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()
        return output_path, f"✅ Box-based segmentation at [{x1}, {y1}, {x2}, {y2}]"
    except Exception as e:
        print(f"Error in box prompt processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks=3):
    """Return up to ``num_masks`` candidate segmentations with confidences.

    Produces one overlay PNG per candidate plus one summary line each.

    Returns:
        tuple: (list_of_png_paths, status_message, info_text)
    """
    if model is None or processor is None:
        return [], "❌ Error: Model not loaded.", ""
    if image_file is None:
        return [], "⚠️ Please upload a medical image file.", ""
    try:
        src = image_file if isinstance(image_file, str) else str(image_file)
        if os.path.splitext(src)[1].lower() == '.dcm':
            # DICOM: rescale then percentile-normalize to 8-bit RGB.
            ds = pydicom.dcmread(src)
            pixels = ds.pixel_array.astype(np.float32)
            pixels = pixels * getattr(ds, 'RescaleSlope', 1) + getattr(ds, 'RescaleIntercept', 0)
            lo = np.percentile(pixels, 1)
            hi = np.percentile(pixels, 99)
            norm = np.clip((pixels - lo) / (hi - lo + 1e-8), 0, 1)
            pil_image = Image.fromarray((norm * 255).astype(np.uint8)).convert('RGB')
        else:
            pil_image = Image.open(src).convert('RGB')
        if not prompt_text or not prompt_text.strip():
            prompt_text = "brain"
        # Low thresholds so medical-image detections are not filtered out.
        sam_results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
        results, mask_info = [], []
        if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
            masks = sam_results['masks']
            scores = sam_results.get('scores', [])
            palette = ['spring', 'cool', 'hot', 'viridis', 'plasma']
            for i in range(min(num_masks, len(masks))):
                candidate = masks[i]
                candidate = candidate.cpu().numpy() if isinstance(candidate, torch.Tensor) else np.array(candidate)
                if candidate.dtype != bool:
                    candidate = candidate > 0.5
                # Scores may be tensors, plain numbers, or absent entirely.
                if i < len(scores):
                    score = scores[i].item() if isinstance(scores[i], torch.Tensor) else scores[i]
                else:
                    score = 0.5
                # One figure per candidate, cycling through the palette.
                plt.figure(figsize=(8, 8))
                plt.imshow(pil_image)
                plt.imshow(candidate, alpha=0.5, cmap=palette[i % len(palette)])
                plt.axis('off')
                plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12)
                handle = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                out_path = handle.name
                handle.close()
                plt.savefig(out_path, bbox_inches='tight', pad_inches=0, dpi=100)
                plt.close()
                results.append(out_path)
                mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(candidate):,} pixels")
        status = f"✅ Generated {len(results)} mask candidate(s)"
        info = "\n".join(mask_info) if mask_info else "No mask information available"
        return results, status, info
    except Exception as e:
        print(f"Error in multi-mask processing: {e}")
        import traceback
        traceback.print_exc()
        return [], f"❌ Error: {str(e)}", ""
def export_to_nifti(image_file, mask, output_name="segmentation"):
    """Export a segmentation mask to NIFTI (.nii.gz) format.

    Args:
        image_file: Source image path; if it is a DICOM file, pixel spacing
            and slice thickness are copied into the NIFTI affine.
        mask: 2-D numpy array holding the segmentation.
        output_name: Currently unused; kept for API compatibility.

    Returns:
        tuple: (path_to_nifti_or_None, status_message)
    """
    if not NIBABEL_AVAILABLE:
        return None, "⚠️ NIFTI export not available - nibabel not installed"
    if mask is None or not isinstance(mask, np.ndarray):
        return None, "⚠️ No valid mask to export"
    try:
        mask_data = mask.astype(np.float32)
        # Default to an identity affine (1 mm isotropic voxels).
        affine = np.eye(4)
        if image_file:
            file_path = image_file if isinstance(image_file, str) else str(image_file)
            if file_path.lower().endswith('.dcm'):
                # Best effort: pull voxel spacing from the DICOM header.
                # FIX: narrowed the former bare ``except:`` so system
                # exceptions (KeyboardInterrupt/SystemExit) are not swallowed.
                try:
                    ds = pydicom.dcmread(file_path, stop_before_pixels=True)
                    pixel_spacing = getattr(ds, 'PixelSpacing', [1.0, 1.0])
                    slice_thickness = getattr(ds, 'SliceThickness', 1.0)
                    affine[0, 0] = float(pixel_spacing[0])
                    affine[1, 1] = float(pixel_spacing[1])
                    affine[2, 2] = float(slice_thickness)
                except Exception:
                    pass  # fall back to the identity affine
        nifti_img = nib.Nifti1Image(mask_data, affine)
        # Save to a temp file (closed first so nibabel can reopen it).
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')
        output_path = output_file.name
        output_file.close()
        nib.save(nifti_img, output_path)
        return output_path, f"✅ Exported to NIFTI: {output_path}"
    except Exception as e:
        print(f"Error exporting to NIFTI: {e}")
        return None, f"❌ Export failed: {str(e)}"
def save_annotation(image_file, mask, prompt_text, modality, stats=None):
    """Bundle a segmentation mask plus metadata into a downloadable ZIP.

    The archive contains ``annotation.json`` (metadata) and ``mask.npz``
    (the compressed mask array).

    Args:
        image_file: Path of the segmented image (only the basename is stored).
        mask: 2-D numpy array; required.
        prompt_text: Text prompt used to produce the mask.
        modality: Imaging modality label (e.g. "CT").
        stats: Optional pre-computed ROI statistics dict.

    Returns:
        tuple: (zip_path_or_None, status_message)
    """
    if mask is None:
        return None, "⚠️ No annotation to save"
    try:
        annotation = {
            "timestamp": datetime.now().isoformat(),
            "image_file": os.path.basename(image_file) if image_file else "unknown",
            "prompt": prompt_text,
            "modality": modality,
            "mask_shape": list(mask.shape),
            "mask_sum": int(np.sum(mask)),
            "mask_base64": None,  # mask is stored as a binary .npz member instead
            "statistics": stats if stats else {}
        }
        # FIX: serialize mask and JSON entirely in memory. The previous
        # implementation wrote intermediate .npz/.json temp files that were
        # never deleted, and recorded the transient .npz temp path in
        # "mask_file"; we now record the archive member name instead.
        mask_buffer = io.BytesIO()
        np.savez_compressed(mask_buffer, mask=mask)
        annotation["mask_file"] = "mask.npz"
        zip_handle = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
        zip_path = zip_handle.name
        zip_handle.close()
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
            zf.writestr('annotation.json', json.dumps(annotation, indent=2))
            zf.writestr('mask.npz', mask_buffer.getvalue())
        return zip_path, f"✅ Annotation saved: {os.path.basename(zip_path)}"
    except Exception as e:
        print(f"Error saving annotation: {e}")
        return None, f"❌ Save failed: {str(e)}"
def load_annotation(annotation_file):
    """Load an annotation ZIP produced by ``save_annotation``.

    Returns:
        tuple: (mask ndarray or None, annotation dict or None, info markdown)
    """
    if annotation_file is None:
        return None, None, "⚠️ No file selected"
    try:
        file_path = annotation_file if isinstance(annotation_file, str) else str(annotation_file)
        if not file_path.endswith('.zip'):
            return None, None, "⚠️ Invalid file format. Please upload a .zip annotation file."
        with zipfile.ZipFile(file_path, 'r') as zf:
            with zf.open('annotation.json') as f:
                annotation = json.load(f)
            # FIX: decompress the mask fully in memory instead of extracting
            # to an un-deleted temp file, and close the NpzFile via the
            # context manager so its handle is released.
            with np.load(io.BytesIO(zf.read('mask.npz'))) as mask_data:
                mask = mask_data['mask']
        info = f"📋 **Loaded Annotation**\n"
        info += f"Image: {annotation.get('image_file', 'unknown')}\n"
        info += f"Prompt: {annotation.get('prompt', 'N/A')}\n"
        info += f"Modality: {annotation.get('modality', 'N/A')}\n"
        info += f"Saved: {annotation.get('timestamp', 'N/A')}\n"
        info += f"Mask size: {annotation.get('mask_sum', 0):,} pixels"
        return mask, annotation, info
    except Exception as e:
        print(f"Error loading annotation: {e}")
        return None, None, f"❌ Load failed: {str(e)}"


def visualize_loaded_annotation(image_file, annotation_file, colormap='spring', transparency=0.5):
    """Overlay a previously saved annotation mask on its original image.

    Returns:
        tuple: (overlay_png_path_or_None, info_or_error_message)
    """
    mask, annotation, info = load_annotation(annotation_file)
    if mask is None:
        return None, info
    if image_file is None:
        return None, "⚠️ Please upload the original image to visualize"
    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.dcm':
            # DICOM: rescale and percentile-normalize to 8-bit RGB.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')
        # Resize the stored mask if the image resolution differs.
        w, h = pil_image.size
        if mask.shape != (h, w):
            mask = np.array(Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))) > 127
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)
        plt.imshow(mask, alpha=transparency, cmap=colormap)
        plt.axis('off')
        plt.title(f"Loaded Annotation: {annotation.get('prompt', 'N/A')}", fontsize=12)
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()
        return output_path, info
    except Exception as e:
        print(f"Error visualizing annotation: {e}")
        return None, f"❌ Visualization failed: {str(e)}"


# Last successful segmentation, kept for export/save operations.
last_processed_mask = {"mask": None, "image_file": None, "prompt": None, "modality": None}
def process_and_store_mask(image_file, prompt_text, modality, window_type):
    """Run segmentation and cache the resulting mask for later export/save."""
    result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
    if not result or mask is None:
        return result, "❌ Processing failed.", ""
    # Remember everything the export/save helpers will need later.
    last_processed_mask.update(
        mask=mask, image_file=image_file, prompt=prompt_text, modality=modality
    )
    stats = calculate_roi_statistics(image_file, mask, modality)
    return result, "✅ Segmentation complete! Ready for export.", format_roi_statistics(stats)


def export_last_mask_nifti():
    """Export the most recently cached mask to NIFTI."""
    cached = last_processed_mask["mask"]
    if cached is None:
        return None, "⚠️ No mask to export. Process an image first."
    return export_to_nifti(last_processed_mask["image_file"], cached)


# ============================================================================
# SAM-MEDICAL-IMAGING UTILITIES (Inspired by amine0110/SAM-Medical-Imaging)
# Automatic Mask Generator, Advanced Transforms, Grid-based Segmentation
# ============================================================================
class ResizeLongestSide:
    """Aspect-ratio-preserving resize to a fixed longest-side length.

    Inspired by SAM-Medical-Imaging transforms.py
    """

    def __init__(self, target_length: int = 1024):
        self.target_length = target_length

    def _scale_for(self, size: tuple) -> float:
        # Scale factor mapping the longest side onto ``target_length``.
        return self.target_length / max(size[0], size[1])

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """Resize an HxW(xC) image, keeping its aspect ratio."""
        h, w = image.shape[:2]
        factor = self._scale_for((h, w))
        resized = Image.fromarray(image).resize((int(w * factor), int(h * factor)), Image.BILINEAR)
        return np.array(resized)

    def apply_coords(self, coords: np.ndarray, original_size: tuple) -> np.ndarray:
        """Rescale (x, y) coordinates from the original to the resized frame."""
        return coords * self._scale_for(original_size)

    def apply_boxes(self, boxes: np.ndarray, original_size: tuple) -> np.ndarray:
        """Rescale (x1, y1, x2, y2) boxes from the original to the resized frame."""
        rescaled = boxes.copy()
        rescaled[..., :2] = self.apply_coords(rescaled[..., :2], original_size)
        rescaled[..., 2:] = self.apply_coords(rescaled[..., 2:], original_size)
        return rescaled


def generate_grid_points(image_size: tuple, points_per_side: int = 32) -> np.ndarray:
    """Build an evenly spaced grid of (x, y) query points over an image.

    Inspired by SAM AMG (Automatic Mask Generator).

    Args:
        image_size: (height, width) of the image
        points_per_side: Number of points per side of the grid

    Returns:
        Array of shape (points_per_side**2, 2) holding (x, y) coordinates.
    """
    h, w = image_size
    xs = np.linspace(0, w - 1, points_per_side)
    ys = np.linspace(0, h - 1, points_per_side)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.column_stack([grid_x.ravel(), grid_y.ravel()])
def automatic_mask_generator(image_file, modality, window_type, points_per_side=16, min_mask_area=100, colormap='tab20', progress=gr.Progress()):
    """
    Automatic Mask Generator (AMG) - Generate masks for entire image without prompts.

    Queries the model with several generic text prompts, filters tiny masks,
    de-duplicates overlapping ones (IoU > 0.8) and renders a colored overlay.
    Inspired by SAM-Medical-Imaging's amg.py

    Args:
        image_file: Path to a DICOM/PNG/JPG image.
        modality: "CT" or other; controls DICOM windowing.
        window_type: CT window preset name.
        points_per_side: Grid resolution. NOTE: the grid currently only feeds
            the progress message; mask generation is text-prompt driven.
        min_mask_area: Minimum mask size in pixels to keep.
        colormap: Matplotlib colormap used to color the regions.
        progress: Gradio progress reporter.

    Returns:
        tuple: (overlay_png_path_or_None, status_message, info_markdown)
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded.", ""
    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""
    try:
        progress(0.1, desc="Loading image...")
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
            # CT windowing presets; other modalities use percentile scaling.
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')
        img_array = np.array(pil_image)
        h, w = img_array.shape[:2]
        progress(0.2, desc="Generating grid points...")
        grid_points = generate_grid_points((h, w), points_per_side)
        total_points = len(grid_points)
        all_masks = []
        all_scores = []
        progress(0.3, desc=f"Processing {total_points} points...")
        # Use different prompts to generate diverse masks.
        prompts = ["anatomical structure", "region", "tissue", "organ", "segment"]
        for prompt_idx, prompt in enumerate(prompts):
            progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")
            try:
                # Low thresholds so medical-image detections are not filtered out.
                sam_results = run_sam3_inference(pil_image, prompt, threshold=0.1, mask_threshold=0.0)
                if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
                    for mask in sam_results['masks']:
                        mask_np = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else np.array(mask)
                        if mask_np.dtype != bool:
                            mask_np = mask_np > 0.5
                        # Filter by minimum area before paying for the resize.
                        mask_area = np.sum(mask_np)
                        if mask_area >= min_mask_area:
                            mask_resized = np.array(
                                Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))
                            ) > 127
                            all_masks.append(mask_resized)
                            all_scores.append(mask_area)
            except Exception as e:
                print(f"Error with prompt '{prompt}': {e}")
                continue
        progress(0.85, desc="Combining masks...")
        if not all_masks:
            return None, "⚠️ No masks generated. Try different parameters.", ""
        # Greedy de-duplication: drop any mask with IoU > 0.8 vs a kept one.
        unique_masks = []
        for mask in all_masks:
            is_duplicate = False
            for existing in unique_masks:
                intersection = np.logical_and(mask, existing).sum()
                union = np.logical_or(mask, existing).sum()
                iou = intersection / union if union > 0 else 0
                if iou > 0.8:
                    is_duplicate = True
                    break
            if not is_duplicate:
                unique_masks.append(mask)
        progress(0.9, desc="Creating visualization...")
        plt.figure(figsize=(12, 12))
        plt.imshow(pil_image)
        # FIX: plt.cm.get_cmap() was removed in Matplotlib 3.9; use the
        # supported pyplot accessor instead.
        cmap = plt.get_cmap(colormap)
        for idx, mask in enumerate(unique_masks):
            color = cmap(idx / max(len(unique_masks), 1))[:3]
            colored_mask = np.zeros((*mask.shape, 4))
            colored_mask[mask] = (*color, 0.4)
            plt.imshow(colored_mask)
        plt.axis('off')
        plt.title(f"Automatic Mask Generation: {len(unique_masks)} regions detected", fontsize=14)
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()
        progress(1.0, desc="Complete!")
        info_text = f"📊 **AMG Results**\n\n"
        info_text += f"**Total Regions:** {len(unique_masks)}\n"
        info_text += f"**Grid Points:** {points_per_side}x{points_per_side} = {points_per_side**2}\n"
        info_text += f"**Min Area Filter:** {min_mask_area} pixels\n\n"
        for idx, mask in enumerate(unique_masks[:10]):  # Show first 10
            area = np.sum(mask)
            percentage = area / (h * w) * 100
            info_text += f"Region {idx+1}: {area:,} pixels ({percentage:.2f}%)\n"
        if len(unique_masks) > 10:
            info_text += f"\n... and {len(unique_masks) - 10} more regions"
        return output_path, f"✅ AMG Complete! Found {len(unique_masks)} regions.", info_text
    except Exception as e:
        print(f"Error in AMG: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", ""
def process_with_advanced_transforms(image_file, prompt_text, modality, window_type, target_size=1024, apply_clahe=False, clahe_clip=2.0, colormap='spring', transparency=0.5):
    """
    Process image with advanced transforms from SAM-Medical-Imaging.
    - ResizeLongestSide: Maintains aspect ratio
    - CLAHE: Contrast Limited Adaptive Histogram Equalization (optional)
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."
    if image_file is None:
        return None, "⚠️ Please upload a medical image file."
    try:
        src = image_file if isinstance(image_file, str) else str(image_file)
        if os.path.splitext(src)[1].lower() == '.dcm':
            # DICOM: rescale, then window (CT) or percentile-normalize.
            ds = pydicom.dcmread(src)
            raw = ds.pixel_array.astype(np.float32)
            raw = raw * getattr(ds, 'RescaleSlope', 1) + getattr(ds, 'RescaleIntercept', 0)
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                lo, hi = level - (width / 2), level + (width / 2)
            else:
                lo, hi = np.percentile(raw, 1), np.percentile(raw, 99)
            img_uint8 = (np.clip((raw - lo) / (hi - lo + 1e-8), 0, 1) * 255).astype(np.uint8)
        else:
            img = Image.open(src)
            if img.mode != 'L':
                img = img.convert('L')
            img_uint8 = np.array(img)
        original_size = img_uint8.shape[:2]
        if apply_clahe:
            try:
                from scipy.ndimage import uniform_filter
                # Poor-man's CLAHE: normalize by local mean/std windows.
                local_mean = uniform_filter(img_uint8.astype(float), size=50)
                local_std = np.sqrt(uniform_filter(img_uint8.astype(float)**2, size=50) - local_mean**2 + 1e-8)
                enhanced = (img_uint8 - local_mean) / (local_std + 1e-8) * clahe_clip
                img_uint8 = np.clip(enhanced * 30 + 128, 0, 255).astype(np.uint8)
            except Exception as e:
                print(f"CLAHE enhancement failed: {e}")
        # Aspect-ratio-preserving resize; grayscale replicated to 3 channels.
        transform = ResizeLongestSide(target_size)
        img_uint8_3ch = np.stack([img_uint8] * 3, axis=-1) if len(img_uint8.shape) == 2 else img_uint8
        img_resized = transform.apply_image(img_uint8_3ch)
        pil_image = Image.fromarray(img_resized)
        if not prompt_text or not prompt_text.strip():
            prompt_text = "brain"
        # Low thresholds so medical-image detections are not filtered out.
        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            collected = [
                m.cpu().numpy() if isinstance(m, torch.Tensor) else np.array(m)
                for m in results['masks']
            ]
            if collected:
                final_mask = np.any(collected, axis=0)
        # Side-by-side figure: original vs transformed + segmentation.
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(Image.fromarray(img_uint8_3ch[:, :, 0] if len(img_uint8_3ch.shape) == 3 else img_uint8_3ch), cmap='gray')
        plt.title(f"Original ({original_size[0]}x{original_size[1]})", fontsize=10)
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(pil_image)
        if final_mask is not None:
            # Resize the mask if it does not match the processed image.
            mh, mw = final_mask.shape
            ih, iw = img_resized.shape[:2]
            if (mh, mw) != (ih, iw):
                final_mask = np.array(
                    Image.fromarray(final_mask.astype(np.uint8) * 255).resize((iw, ih))
                ) > 127
            plt.imshow(final_mask, alpha=transparency, cmap=colormap)
        plt.title(f"Transformed ({target_size}px) + Segmentation", fontsize=10)
        plt.axis('off')
        plt.tight_layout()
        handle = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = handle.name
        handle.close()
        plt.savefig(output_path, bbox_inches='tight', dpi=100)
        plt.close()
        status = f"✅ Processed with ResizeLongestSide({target_size})"
        if apply_clahe:
            status += f" + CLAHE(clip={clahe_clip})"
        return output_path, status
    except Exception as e:
        print(f"Error in advanced transforms: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
""" if image_file is None: return None, "⚠️ Please upload a medical image file." try: # Load image file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept if modality == "CT": if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) else: img = Image.open(file_path).convert('L') img_uint8 = np.array(img) # Compute Sobel edges from scipy.ndimage import sobel, binary_dilation, binary_fill_holes # Sobel edge detection dx = sobel(img_uint8.astype(float), axis=1) dy = sobel(img_uint8.astype(float), axis=0) edges = np.hypot(dx, dy) # Threshold edges edge_mask = edges > edge_threshold # Dilate edges to connect nearby components if dilation_size > 0: struct = np.ones((dilation_size, dilation_size)) edge_mask = binary_dilation(edge_mask, structure=struct) # Fill holes filled_mask = binary_fill_holes(edge_mask) # Label connected components labeled, num_features = ndimage.label(filled_mask) # Create RGB image for display pil_image = Image.fromarray(img_uint8).convert('RGB') # Visualize plt.figure(figsize=(15, 5)) plt.subplot(1, 3, 1) plt.imshow(img_uint8, cmap='gray') plt.title("Original Image", fontsize=10) plt.axis('off') plt.subplot(1, 3, 2) plt.imshow(edges, cmap='hot') plt.title(f"Edge Detection (threshold={edge_threshold})", fontsize=10) plt.axis('off') plt.subplot(1, 3, 3) plt.imshow(pil_image) plt.imshow(filled_mask, 
def save_last_annotation():
    """Save the most recently processed annotation (mask + stats) as a ZIP."""
    if last_processed_mask["mask"] is None:
        return None, "⚠️ No annotation to save. Process an image first."
    stats = calculate_roi_statistics(
        last_processed_mask["image_file"],
        last_processed_mask["mask"],
        last_processed_mask["modality"],
    )
    return save_annotation(
        last_processed_mask["image_file"],
        last_processed_mask["mask"],
        last_processed_mask["prompt"],
        last_processed_mask["modality"],
        stats,
    )


# Create Gradio Interface
# Path of the bundled demo DICOM, if it is actually present on disk.
demo_file_path = demo_dicom_path if demo_file_available and os.path.exists(demo_dicom_path) else None


def load_demo_file():
    """Return (path, status) for the bundled demo DICOM file."""
    if demo_file_path and os.path.exists(demo_file_path):
        return demo_file_path, f"✅ Demo file loaded: {demo_file_path}\nReady to segment!"
    return None, "⚠️ Demo file not found. Please upload a medical image file (DICOM, PNG, or JPG)."


def process_with_status(image_file, prompt_text, modality, window_type):
    """Wrapper around process_medical_image that reports a status string."""
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."
    if image_file is None:
        return None, "⚠️ Please upload a medical image file (DICOM, PNG, or JPG) or load the demo file."
    result = process_medical_image(image_file, prompt_text, modality, window_type)
    if result is None:
        return None, "❌ Processing failed. Check console for error details."
    return result, "✅ Segmentation complete!"
def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type): """Process image and compare with ground truth segmentation mask.""" if model is None or processor is None: return None, None, 0.0, 0.0, "❌ Error: Model not loaded." if image_file is None: return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file." if gt_mask_file is None: return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file." # Process image and get mask result, pred_mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True) if result is None or pred_mask is None: return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details." # Compare with ground truth comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file) if comparison_path: status = f"✅ Segmentation complete!\nDice Score: {dice_score:.3f}\nIoU Score: {iou_score:.3f}" return result, comparison_path, dice_score, iou_score, status else: return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed." def process_sequence(image_files, prompt_text, modality, window_type): """Process multiple images from the same subject and return gallery of results.""" if model is None or processor is None: return [], "❌ Error: Model not loaded." if not image_files: return [], "⚠️ Please upload medical image files (DICOM, PNG, or JPG)." # Handle single file or list of files if isinstance(image_files, str): image_files = [image_files] results = [] status_messages = [] for idx, image_file in enumerate(image_files): if image_file is None: continue status_msg = f"Processing image {idx + 1}/{len(image_files)}..." 
status_messages.append(status_msg) result = process_medical_image(image_file, prompt_text, modality, window_type) if result: results.append(result) status_messages.append(f"✅ Image {idx + 1} segmented successfully") else: status_messages.append(f"❌ Failed to process image {idx + 1}") if results: status = f"✅ Processed {len(results)}/{len(image_files)} images successfully!\n" + "\n".join(status_messages) return results, status else: return [], "❌ No images were processed successfully. Check console for error details." # Store processed results for interactive viewer processed_results_cache = {} def extract_subject_id(file_path): """Extract subject/patient ID from file path. Common patterns: - Folder name: /subject_001/image.png -> subject_001 - Filename prefix: subject_001_slice_01.png -> subject_001 - Patient ID in filename: patient_123_slice_5.dcm -> patient_123 - Study UID in DICOM: extract from DICOM metadata Returns: tuple: (subject_id, confidence_level, source) confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename pattern), 'low' (fallback) source: 'dicom_patientid', 'dicom_study', 'folder', 'filename', 'fallback' """ import re file_path = str(file_path) filename = os.path.basename(file_path) dir_path = os.path.dirname(file_path) # HIGHEST CONFIDENCE: DICOM metadata (most reliable) if file_path.lower().endswith('.dcm'): try: ds = pydicom.dcmread(file_path, stop_before_pixels=True) patient_id = getattr(ds, 'PatientID', None) if patient_id and patient_id.strip(): return f"patient_{patient_id}", 'high', 'dicom_patientid' study_uid = getattr(ds, 'StudyInstanceUID', None) if study_uid: # Use full study UID as identifier (unique per study) return f"study_{study_uid}", 'high', 'dicom_study' except: pass # MEDIUM CONFIDENCE: Folder name (common in medical datasets) folder_name = os.path.basename(dir_path.rstrip('/')) if folder_name and folder_name not in ['', '.', '..']: # Check if folder name looks like a subject ID if 
re.match(r'(subject|patient|sub|pat|case|id)[_-]?\d+', folder_name, re.I): return folder_name, 'medium', 'folder' # MEDIUM CONFIDENCE: Filename pattern patterns = [ (r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', 'medium'), # subject_001, patient_123 (r'([A-Z]{2,}\d+)', 'medium'), # BR001, MR123, etc. ] for pattern, confidence in patterns: match = re.search(pattern, filename, re.I) if match: if len(match.groups()) > 1: return f"{match.group(1)}_{match.group(2)}", confidence, 'filename' else: return match.group(1), confidence, 'filename' # LOW CONFIDENCE: Numeric pattern (could be slice number, not patient ID) numeric_match = re.search(r'(\d{3,})', filename) if numeric_match: return numeric_match.group(1), 'low', 'filename_numeric' # LOWEST CONFIDENCE: Fallback to filename base_name = os.path.splitext(filename)[0] if len(base_name) > 0: return base_name, 'low', 'fallback' return "unknown", 'low', 'unknown' def group_images_by_subject(image_files): """Group image files by subject/patient ID. 
Returns: dict: {subject_id: {'files': [...], 'confidence': 'high/medium/low', 'sources': set(...)}} """ if not image_files: return {} if isinstance(image_files, str): image_files = [image_files] # Filter out None files image_files = [f for f in image_files if f is not None] # Group by subject ID and track confidence subject_groups = {} for file_path in image_files: subject_id, confidence, source = extract_subject_id(file_path) if subject_id not in subject_groups: subject_groups[subject_id] = { 'files': [], 'confidence': confidence, 'sources': set([source]) } subject_groups[subject_id]['files'].append(file_path) subject_groups[subject_id]['sources'].add(source) # Upgrade confidence if we find high-confidence source if confidence == 'high' or (confidence == 'medium' and subject_groups[subject_id]['confidence'] == 'low'): subject_groups[subject_id]['confidence'] = confidence # Sort files within each group (by filename) for subject_id in subject_groups: subject_groups[subject_id]['files'].sort() subject_groups[subject_id]['sources'] = list(subject_groups[subject_id]['sources']) return subject_groups def detect_subjects(image_files): """Detect and return subject groups from uploaded files.""" if not image_files: return gr.Dropdown(choices=[], value=None), "No files uploaded" subject_groups = group_images_by_subject(image_files) if not subject_groups: return gr.Dropdown(choices=[], value=None), "No subjects detected" choices = [] status_msg = f"✅ Detected {len(subject_groups)} subject(s):\n\n" for subject_id, info in sorted(subject_groups.items()): num_files = len(info['files']) confidence = info['confidence'] sources = ', '.join(info['sources']) # Add confidence indicator if confidence == 'high': confidence_icon = "✅" confidence_text = "HIGH (DICOM metadata - very reliable)" elif confidence == 'medium': confidence_icon = "⚠️" confidence_text = "MEDIUM (folder/filename pattern - likely same patient)" else: confidence_icon = "⚠️⚠️" confidence_text = "LOW (filename-based - 
verify manually)" choices.append(f"{subject_id} ({num_files} slices)") status_msg += f"{confidence_icon} **{subject_id}**: {num_files} slices\n" status_msg += f" Confidence: {confidence_text}\n" status_msg += f" Source: {sources}\n\n" # Add warning for low confidence low_confidence_count = sum(1 for info in subject_groups.values() if info['confidence'] == 'low') if low_confidence_count > 0: status_msg += f"⚠️ **Warning**: {low_confidence_count} subject(s) detected with LOW confidence.\n" status_msg += "Please verify these are actually the same patient before proceeding.\n" return gr.Dropdown(choices=choices, value=choices[0] if choices else None), status_msg def process_slices_for_viewer(image_files, selected_subject, prompt_text, modality, window_type): """Process all slices for selected subject and cache results for interactive viewing.""" if model is None or processor is None: return None, 0, "❌ Error: Model not loaded.", "No slices loaded", gr.Dropdown(choices=[], value=None), "" if not image_files: return None, 0, "⚠️ Please upload medical image files.", "No slices loaded", gr.Dropdown(choices=[], value=None), "" # Group by subject subject_groups = group_images_by_subject(image_files) if not subject_groups: return None, 0, "⚠️ Could not detect subjects in uploaded files.", "No slices loaded", gr.Dropdown(choices=[], value=None), "" # Extract subject ID from selection (format: "subject_id (N slices)") if selected_subject: subject_id = selected_subject.split(" (")[0] else: # Use first subject if none selected subject_id = list(subject_groups.keys())[0] if subject_id not in subject_groups: return None, 0, f"⚠️ Subject '{subject_id}' not found.", "No slices loaded", gr.Dropdown(choices=[], value=None), "" # Get files for selected subject subject_info = subject_groups[subject_id] subject_files = subject_info['files'] confidence = subject_info['confidence'] # Add confidence warning confidence_warning = "" if confidence == 'low': confidence_warning = "\n⚠️ LOW 
CONFIDENCE: These files may not be from the same patient. Please verify!" elif confidence == 'medium': confidence_warning = "\n⚠️ MEDIUM CONFIDENCE: Likely same patient, but verify if critical." results = [] status_messages = [] for idx, image_file in enumerate(subject_files): status_msg = f"Processing slice {idx + 1}/{len(subject_files)}..." status_messages.append(status_msg) result = process_medical_image(image_file, prompt_text, modality, window_type) if result: results.append(result) status_messages.append(f"✅ Slice {idx + 1} processed") else: status_messages.append(f"❌ Failed to process slice {idx + 1}") if results: # Cache results with a unique key including subject ID cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}" processed_results_cache[cache_key] = results max_slices = len(results) - 1 status = f"✅ Processed {len(results)}/{len(subject_files)} slices for {subject_id}!\nUse slider or buttons to navigate.{confidence_warning}" slice_info = f"Slice 1/{len(results)} ({subject_id})" # Update subject dropdown choices choices = [] for sid, info in sorted(subject_groups.items()): marker = "✓" if sid == subject_id else "" num_files = len(info['files']) choices.append(f"{marker} {sid} ({num_files} slices)") return results[0], max_slices, status, slice_info, gr.Dropdown(choices=choices, value=choices[0] if choices else None), f"Viewing: {subject_id}" else: return None, 0, "❌ No slices were processed successfully.", "No slices loaded", gr.Dropdown(choices=[], value=None), "" def navigate_slice(slice_idx, image_files, selected_subject, prompt_text, modality, window_type): """Navigate to a specific slice in the sequence.""" if not image_files: return None, "No slices loaded" # Group by subject and get selected subject's files subject_groups = group_images_by_subject(image_files) if selected_subject: subject_id = selected_subject.split(" (")[0] else: subject_id = list(subject_groups.keys())[0] if subject_groups else None if not subject_id or 
subject_id not in subject_groups: return None, "No slices loaded" subject_info = subject_groups[subject_id] subject_files = subject_info['files'] slice_idx = int(slice_idx) cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}" if cache_key in processed_results_cache: results = processed_results_cache[cache_key] if 0 <= slice_idx < len(results): slice_info = f"Slice {slice_idx + 1}/{len(results)} ({subject_id})" return results[slice_idx], slice_info # If not cached, process on the fly (fallback) if 0 <= slice_idx < len(subject_files): result = process_medical_image(subject_files[slice_idx], prompt_text, modality, window_type) if result: slice_info = f"Slice {slice_idx + 1}/{len(subject_files)} ({subject_id})" return result, slice_info return None, f"Invalid slice index: {slice_idx}" with gr.Blocks() as demo: gr.Markdown("# 🏥 NeuroSAM 3: Medical Image Segmentation") demo_info = "" if demo_file_path: demo_info = f"\n\n**📁 Demo File Available:** A sample DICOM file is ready: `{demo_file_path}`\nClick 'Load Demo File' button below to use it!" gr.Markdown(f""" Upload a medical image (DICOM .dcm, PNG, or JPG) and type what you want to find (e.g., 'brain', 'tumor', 'skull'). {demo_info} **Instructions:** 1. Upload a medical image file: - DICOM (.dcm) files for CT or MRI scans - PNG/JPG files (e.g., from Kaggle datasets like brain MRI images) 2. Enter a text prompt describing what to segment 3. Select the imaging modality (CT or MRI) 4. Choose the windowing strategy (for CT images) 5. 
Click "Segment Structure" to process **Supported Formats:** - DICOM (.dcm) - Standard medical imaging format - PNG/JPG - Standard image formats (works with Kaggle brain MRI datasets) """) with gr.Tabs(): with gr.Tab("Single Image"): with gr.Row(): with gr.Column(): file_input = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath", value=demo_file_path ) load_demo_btn = gr.Button( "📁 Load Demo File", variant="secondary", size="sm", visible=bool(demo_file_path) ) text_input = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) submit_btn = gr.Button("Segment Structure", variant="primary", size="lg") with gr.Column(): image_output = gr.Image( label="Segmentation Result", type="filepath" ) gr.Markdown("### Status") status_text = gr.Textbox( label="Processing Status", value="Ready. Upload a medical image file (DICOM, PNG, or JPG) to begin.", interactive=False ) with gr.Tab("Interactive Slice Viewer"): gr.Markdown("**Scroll through multiple slices/images from the same subject interactively**") gr.Markdown(""" **📋 Subject Detection:** The app automatically detects subject/patient IDs from: - Folder names (e.g., `subject_001/`, `patient_123/`) - Filenames (e.g., `subject_001_slice_01.png`, `patient_123.dcm`) - DICOM metadata (PatientID, StudyInstanceUID) **💡 Tip:** Upload images organized by subject folders for best results! 
""") with gr.Row(): with gr.Column(): files_input = gr.File( label="Upload Multiple Images/Slices (Select multiple files)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], file_count="multiple", type="filepath" ) subject_dropdown = gr.Dropdown( label="Select Subject/Patient", choices=[], value=None, interactive=True, info="Select which subject's slices to view (auto-detected from filenames/folders)" ) text_input_batch = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown_batch = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown_batch = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) detect_subjects_btn = gr.Button("🔍 Detect Subjects", variant="secondary", size="sm") submit_batch_btn = gr.Button("Process All Slices", variant="primary", size="lg") gr.Markdown("---") gr.Markdown("### 🎛️ Slice Navigator") slice_slider = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Slice Number", info="Use slider or arrow keys to navigate through slices", interactive=False ) with gr.Row(): prev_btn = gr.Button("⬆️ Previous Slice", size="sm") next_btn = gr.Button("⬇️ Next Slice", size="sm") auto_play_btn = gr.Button("▶️ Auto-play", size="sm") with gr.Column(): current_slice_output = gr.Image( label="Current Slice Segmentation", type="filepath", height=600 ) gr.Markdown("### Slice Info") slice_info_text = gr.Textbox( label="Current Slice", value="No slices loaded", interactive=False ) subject_info_text = gr.Textbox( label="Subject Info", value="", interactive=False, visible=False ) gr.Markdown("### Status") status_batch_text = gr.Textbox( label="Processing Status", value="Ready. 
Upload multiple medical image files to process a sequence.", interactive=False, lines=4 ) with gr.Tab("Gallery View"): gr.Markdown("**View all segmentations in a gallery grid**") with gr.Row(): with gr.Column(): files_input_gallery = gr.File( label="Upload Multiple Images (Select multiple files)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], file_count="multiple", type="filepath" ) text_input_gallery = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. brain, tumor, skull, eyes" ) with gr.Row(): modality_dropdown_gallery = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI" ) window_dropdown_gallery = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)" ) submit_gallery_btn = gr.Button("Process & Show Gallery", variant="primary", size="lg") with gr.Column(): gallery_output = gr.Gallery( label="Segmentation Gallery", show_label=True, elem_id="gallery", columns=2, rows=2, height="auto" ) status_gallery_text = gr.Textbox( label="Status", value="Ready. Upload multiple images to view in gallery.", interactive=False ) with gr.Tab("Compare with Ground Truth"): gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**") with gr.Row(): with gr.Column(): file_input_gt = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) gt_mask_input = gr.File( label="Upload Ground Truth Mask (PNG, JPG)", file_types=[".png", ".jpg", ".jpeg"], type="filepath" ) text_input_gt = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown_gt = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown_gt = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg") with gr.Column(): image_output_gt = gr.Image( label="SAM 3 Segmentation", type="filepath" ) comparison_output = gr.Image( label="Comparison: SAM 3 vs Ground Truth", type="filepath" ) gr.Markdown("### Metrics") dice_score_text = gr.Textbox( label="Dice Score", value="--", interactive=False ) iou_score_text = gr.Textbox( label="IoU Score", value="--", interactive=False ) gr.Markdown("### Status") status_gt_text = gr.Textbox( label="Processing Status", value="Ready. Upload image and ground truth mask to compare.", interactive=False, lines=3 ) # NEW ENHANCED TABS with gr.Tab("✨ Enhanced Single Image"): gr.Markdown("**Enhanced version with image adjustments, visualization options, and progress tracking**") with gr.Row(): with gr.Column(): file_input_enh = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath", value=demo_file_path ) text_input_enh = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_enh = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_enh = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) # NEW: Image adjustment controls gr.Markdown("### 🎨 Image Adjustments") brightness_slider = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Brightness", info="Adjust image brightness (0.5 = darker, 2.0 = brighter)" ) contrast_slider = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Contrast", info="Adjust image contrast (0.5 = lower contrast, 2.0 = higher contrast)" ) # NEW: Visualization options gr.Markdown("### 🎭 Visualization Options") colormap_dropdown = gr.Dropdown( ["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"], label="Mask Colormap", value="spring", info="Color scheme for segmentation overlay" ) transparency_slider = gr.Slider( 0.0, 1.0, value=0.5, step=0.1, label="Mask Transparency", info="Transparency of segmentation overlay (0.0 = opaque, 1.0 = transparent)" ) submit_enh_btn = gr.Button("Segment with Progress", variant="primary", size="lg") with gr.Column(): # Progress indicator progress_text = gr.Textbox( label="Progress", value="Ready. 
Upload a medical image file to begin.", interactive=False ) image_output_enh = gr.Image( label="Segmentation Result", type="filepath" ) # NEW: Download button download_output = gr.File( label="Download Result", visible=True ) # NEW: Metrics display gr.Markdown("### 📊 Metrics") metrics_text = gr.Textbox( label="Segmentation Info", value="", lines=3, interactive=False ) with gr.Tab("✨ Enhanced Batch Processing"): gr.Markdown("**Enhanced batch processing with ZIP download and progress tracking**") with gr.Row(): with gr.Column(): files_input_enh_batch = gr.File( label="Upload Multiple Images for Batch Processing", file_types=[".dcm", ".png", ".jpg", ".jpeg"], file_count="multiple", type="filepath" ) text_input_enh_batch = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. brain, tumor, skull, eyes" ) with gr.Row(): modality_enh_batch = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI" ) window_enh_batch = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)" ) # Image adjustments gr.Markdown("### 🎨 Image Adjustments") brightness_slider_batch = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Brightness" ) contrast_slider_batch = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Contrast" ) # Visualization options gr.Markdown("### 🎭 Visualization Options") colormap_dropdown_batch = gr.Dropdown( ["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"], label="Mask Colormap", value="spring" ) transparency_slider_batch = gr.Slider( 0.0, 1.0, value=0.5, step=0.1, label="Mask Transparency" ) submit_enh_batch_btn = gr.Button("Process Batch with Progress", variant="primary", size="lg") with gr.Column(): gallery_output_enh = gr.Gallery( label="Segmentation Gallery", show_label=True, elem_id="gallery", columns=2, rows=2, height="auto" ) # Batch download ZIP batch_download_output = gr.File( label="Download All Results (ZIP)", visible=True ) status_enh_batch_text 
= gr.Textbox( label="Status", value="Ready. Upload multiple images to process in batch.", interactive=False, lines=4 ) # NEW: Point/Box Prompts Tab with gr.Tab("🎯 Point/Box Prompts"): gr.Markdown(""" **Interactive Point and Box-based Segmentation** Use precise point clicks or bounding boxes to guide the segmentation. - **Point Prompt**: Click on the region you want to segment - **Box Prompt**: Define a bounding box around the region of interest """) with gr.Tabs(): with gr.Tab("Point Prompt"): with gr.Row(): with gr.Column(): file_input_point = gr.File( label="Upload Medical Image", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) gr.Markdown("### Point Coordinates") with gr.Row(): point_x = gr.Number(label="X coordinate", value=128, precision=0) point_y = gr.Number(label="Y coordinate", value=128, precision=0) with gr.Row(): modality_point = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI") window_point = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing", value="Brain (Grey Matter)" ) with gr.Row(): colormap_point = gr.Dropdown( ["spring", "cool", "hot", "viridis", "plasma"], label="Colormap", value="spring" ) transparency_point = gr.Slider(0.0, 1.0, value=0.5, label="Transparency") submit_point_btn = gr.Button("Segment at Point", variant="primary") with gr.Column(): output_point = gr.Image(label="Point Segmentation Result", type="filepath") status_point = gr.Textbox(label="Status", interactive=False) with gr.Tab("Box Prompt"): with gr.Row(): with gr.Column(): file_input_box = gr.File( label="Upload Medical Image", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) gr.Markdown("### Bounding Box Coordinates") with gr.Row(): box_x1 = gr.Number(label="X1 (left)", value=50, precision=0) box_y1 = gr.Number(label="Y1 (top)", value=50, precision=0) with gr.Row(): box_x2 = gr.Number(label="X2 (right)", value=200, precision=0) box_y2 = gr.Number(label="Y2 (bottom)", value=200, precision=0) with 
gr.Row(): modality_box = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI") window_box = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing", value="Brain (Grey Matter)" ) with gr.Row(): colormap_box = gr.Dropdown( ["spring", "cool", "hot", "viridis", "plasma"], label="Colormap", value="spring" ) transparency_box = gr.Slider(0.0, 1.0, value=0.5, label="Transparency") submit_box_btn = gr.Button("Segment in Box", variant="primary") with gr.Column(): output_box = gr.Image(label="Box Segmentation Result", type="filepath") status_box = gr.Textbox(label="Status", interactive=False) # NEW: ROI Statistics & Export Tab with gr.Tab("📊 ROI Statistics & Export"): gr.Markdown(""" **ROI Statistics and Export Options** Process an image and get detailed statistics about the segmented region: - Area (pixels and percentage) - Intensity statistics (mean, std, min, max) - Centroid and bounding box - Export to NIFTI format for medical imaging software - Save/Load annotations for later use """) with gr.Row(): with gr.Column(): file_input_stats = gr.File( label="Upload Medical Image", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) text_input_stats = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull" ) with gr.Row(): modality_stats = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI") window_stats = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing", value="Brain (Grey Matter)" ) submit_stats_btn = gr.Button("Process & Calculate Statistics", variant="primary") gr.Markdown("### Export Options") with gr.Row(): export_nifti_btn = gr.Button("📥 Export to NIFTI", size="sm") save_annotation_btn = gr.Button("💾 Save Annotation", size="sm") with gr.Column(): output_stats = gr.Image(label="Segmentation Result", type="filepath") status_stats = gr.Textbox(label="Status", interactive=False) gr.Markdown("### 📊 ROI Statistics") roi_stats_text = gr.Markdown(value="*Process an image to see statistics*") nifti_download = gr.File(label="Download NIFTI", visible=True) annotation_download = gr.File(label="Download Annotation", visible=True) gr.Markdown("---") gr.Markdown("### Load Saved Annotation") with gr.Row(): with gr.Column(): annotation_upload = gr.File( label="Upload Annotation (.zip)", file_types=[".zip"], type="filepath" ) original_image_upload = gr.File( label="Upload Original Image (for visualization)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) load_annotation_btn = gr.Button("Load & Visualize Annotation", variant="secondary") with gr.Column(): loaded_annotation_output = gr.Image(label="Loaded Annotation", type="filepath") loaded_annotation_info = gr.Markdown(value="*Upload an annotation file to load*") # NEW: Multi-Mask Output Tab with gr.Tab("🎭 Multi-Mask Output"): gr.Markdown(""" **Generate Multiple Mask Candidates** SAM can generate multiple segmentation hypotheses with confidence scores. This is useful when the segmentation is ambiguous or you want to compare alternatives. 
""") with gr.Row(): with gr.Column(): file_input_multi = gr.File( label="Upload Medical Image", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) text_input_multi = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. brain, tumor, skull" ) with gr.Row(): modality_multi = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI") window_multi = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing", value="Brain (Grey Matter)" ) num_masks_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Masks") submit_multi_btn = gr.Button("Generate Multiple Masks", variant="primary") with gr.Column(): gallery_multi = gr.Gallery( label="Mask Candidates", show_label=True, columns=2, rows=2, height="auto" ) status_multi = gr.Textbox(label="Status", interactive=False) mask_info_multi = gr.Textbox(label="Mask Information", lines=5, interactive=False) # NEW: SAM-Medical-Imaging Inspired Tabs with gr.Tab("🔬 Automatic Mask Generator"): gr.Markdown(""" **Automatic Mask Generator (AMG)** Inspired by [SAM-Medical-Imaging](https://github.com/amine0110/SAM-Medical-Imaging) - automatically segment the entire image without needing specific prompts. 
- Uses multiple prompts internally to discover all regions
- Filters and deduplicates overlapping masks
- Great for exploratory analysis of medical images
        """)
        # --- AMG tab layout: controls on the left, results on the right ---
        with gr.Row():
            with gr.Column():
                # Accepts DICOM or common raster formats; handlers receive a filesystem path.
                file_input_amg = gr.File(
                    label="Upload Medical Image",
                    file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                    type="filepath"
                )
                with gr.Row():
                    modality_amg = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                    window_amg = gr.Dropdown(
                        ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                        label="Windowing", value="Brain (Grey Matter)"
                    )
                gr.Markdown("### AMG Parameters")
                # Density of the point-prompt grid sampled by the mask generator.
                points_per_side = gr.Slider(
                    8, 32, value=16, step=4,
                    label="Grid Density",
                    info="Higher = more detailed but slower"
                )
                # Masks smaller than this pixel area are discarded as noise.
                min_mask_area = gr.Slider(
                    50, 1000, value=100, step=50,
                    label="Minimum Mask Area (pixels)",
                    info="Filter out small noise regions"
                )
                colormap_amg = gr.Dropdown(
                    ["tab20", "tab10", "Set1", "Set2", "Paired", "rainbow"],
                    label="Colormap for Regions", value="tab20"
                )
                submit_amg_btn = gr.Button("🔬 Run Automatic Segmentation", variant="primary", size="lg")
            with gr.Column():
                output_amg = gr.Image(label="AMG Result", type="filepath")
                status_amg = gr.Textbox(label="Status", interactive=False)
                info_amg = gr.Markdown(value="*Run AMG to see region details*")

    # --- Advanced Transforms tab: preprocessing (resize / CLAHE) before segmentation ---
    with gr.Tab("🔧 Advanced Transforms"):
        gr.Markdown("""
**Advanced Image Transforms**

Apply preprocessing transforms from SAM-Medical-Imaging before segmentation:
- **ResizeLongestSide**: Resize maintaining aspect ratio (optimal for SAM)
- **CLAHE-like Enhancement**: Enhance local contrast for better visibility
        """)
        with gr.Row():
            with gr.Column():
                file_input_transform = gr.File(
                    label="Upload Medical Image",
                    file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                    type="filepath"
                )
                text_input_transform = gr.Textbox(
                    label="Text Prompt", value="brain",
                    placeholder="e.g. brain, tumor, skull"
                )
                with gr.Row():
                    modality_transform = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                    window_transform = gr.Dropdown(
                        ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                        label="Windowing", value="Brain (Grey Matter)"
                    )
                gr.Markdown("### Transform Settings")
                # Longest side of the image is resized to this value (aspect ratio kept).
                target_size_slider = gr.Slider(
                    256, 2048, value=1024, step=128,
                    label="Target Size (ResizeLongestSide)",
                    info="Resize image's longest side to this value"
                )
                apply_clahe_checkbox = gr.Checkbox(
                    label="Apply CLAHE-like Enhancement", value=False,
                    info="Enhance local contrast (useful for low-contrast images)"
                )
                # Only meaningful when the CLAHE checkbox is enabled.
                clahe_clip_slider = gr.Slider(
                    1.0, 4.0, value=2.0, step=0.5,
                    label="CLAHE Clip Limit",
                    info="Higher = more contrast enhancement"
                )
                with gr.Row():
                    colormap_transform = gr.Dropdown(
                        ["spring", "cool", "hot", "viridis", "plasma"],
                        label="Colormap", value="spring"
                    )
                    transparency_transform = gr.Slider(
                        0.0, 1.0, value=0.5, step=0.1, label="Transparency"
                    )
                submit_transform_btn = gr.Button("🔧 Process with Transforms", variant="primary")
            with gr.Column():
                output_transform = gr.Image(label="Transform + Segmentation Result", type="filepath")
                status_transform = gr.Textbox(label="Status", interactive=False)

    # --- Edge-Based Segmentation tab: classical (non-model) Sobel pipeline ---
    with gr.Tab("🌊 Edge-Based Segmentation"):
        gr.Markdown("""
**Edge-Based Automatic Segmentation**

Uses Sobel edge detection to find boundaries in medical images.
Works without AI model - useful for comparison or when model fails.
- Detects edges using gradient-based methods
- Fills regions within boundaries
- Identifies distinct anatomical structures
        """)
        with gr.Row():
            with gr.Column():
                file_input_edge = gr.File(
                    label="Upload Medical Image",
                    file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                    type="filepath"
                )
                with gr.Row():
                    modality_edge = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                    window_edge = gr.Dropdown(
                        ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                        label="Windowing", value="Brain (Grey Matter)"
                    )
                gr.Markdown("### Edge Detection Parameters")
                # Gradient magnitude cutoff: lower values keep weaker edges.
                edge_threshold_slider = gr.Slider(
                    10, 150, value=50, step=10,
                    label="Edge Threshold",
                    info="Lower = more edges detected"
                )
                # Morphological dilation radius used to bridge broken edge segments.
                dilation_size_slider = gr.Slider(
                    0, 10, value=3, step=1,
                    label="Dilation Size",
                    info="Connect nearby edges (0 = no dilation)"
                )
                with gr.Row():
                    colormap_edge = gr.Dropdown(
                        ["spring", "cool", "hot", "viridis", "plasma"],
                        label="Colormap", value="spring"
                    )
                    transparency_edge = gr.Slider(
                        0.0, 1.0, value=0.5, step=0.1, label="Transparency"
                    )
                submit_edge_btn = gr.Button("🌊 Run Edge Segmentation", variant="primary")
            with gr.Column():
                output_edge = gr.Image(label="Edge Segmentation Result", type="filepath")
                status_edge = gr.Textbox(label="Status", interactive=False)

    # =====================================================================
    # Event wiring: connect buttons/sliders to the handler functions
    # defined earlier in this file.
    # =====================================================================

    # Single image processing
    load_demo_btn.click(
        fn=load_demo_file,
        inputs=[],
        outputs=[file_input, status_text]
    )
    submit_btn.click(
        fn=process_with_status,
        inputs=[file_input, text_input, modality_dropdown, window_dropdown],
        outputs=[image_output, status_text]
    )

    # Detect subjects when files are uploaded
    detect_subjects_btn.click(
        fn=detect_subjects,
        inputs=[files_input],
        outputs=[subject_dropdown, status_batch_text]
    )

    # Interactive slice viewer: after processing, unlock the slider and grow its
    # maximum. NOTE(review): the `.then` lambda receives slice_slider's *value*
    # as `max_val` — presumably process_slices_for_viewer sets the slider value
    # to the slice count first; verify against that handler.
    submit_batch_btn.click(
        fn=process_slices_for_viewer,
        inputs=[files_input, subject_dropdown, text_input_batch,
                modality_dropdown_batch, window_dropdown_batch],
        outputs=[current_slice_output, slice_slider, status_batch_text,
                 slice_info_text, subject_dropdown, subject_info_text]
    ).then(
        lambda max_val: gr.Slider(maximum=max(max_val, 1), interactive=True),
        inputs=[slice_slider],
        outputs=[slice_slider]
    )

    def update_slice(slice_num, files, selected_subject, prompt, mod, window):
        # Slider callback: render the slice at `slice_num` (slider values may be
        # floats, hence the int() cast) via the shared navigate_slice helper.
        result, info = navigate_slice(int(slice_num), files, selected_subject, prompt, mod, window)
        return result, info

    slice_slider.change(
        fn=update_slice,
        inputs=[slice_slider, files_input, subject_dropdown, text_input_batch,
                modality_dropdown_batch, window_dropdown_batch],
        outputs=[current_slice_output, slice_info_text]
    )

    def prev_slice(current, files, selected_subject, prompt, mod, window):
        # Step one slice backwards, clamped at 0; returns the new slider value
        # plus the rendered slice and its info text.
        new_val = max(0, current - 1)
        result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
        return new_val, result, info

    def next_slice(current, max_val, files, selected_subject, prompt, mod, window):
        # Step one slice forwards, clamped at `max_val`.
        # NOTE(review): this button is wired with slice_slider supplying BOTH
        # `current` and `max_val` (same value), so min(max_val, current + 1)
        # equals `current` and the button cannot advance — confirm and rewire
        # with the real slice count.
        new_val = min(max_val, current + 1)
        result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
        return new_val, result, info

    prev_btn.click(
        fn=prev_slice,
        inputs=[slice_slider, files_input, subject_dropdown, text_input_batch,
                modality_dropdown_batch, window_dropdown_batch],
        outputs=[slice_slider, current_slice_output, slice_info_text]
    )
    next_btn.click(
        fn=next_slice,
        # slice_slider is passed twice: once as `current`, once as `max_val`.
        inputs=[slice_slider, slice_slider, files_input, subject_dropdown,
                text_input_batch, modality_dropdown_batch, window_dropdown_batch],
        outputs=[slice_slider, current_slice_output, slice_info_text]
    )

    # Gallery view
    submit_gallery_btn.click(
        fn=process_sequence,
        inputs=[files_input_gallery, text_input_gallery,
                modality_dropdown_gallery, window_dropdown_gallery],
        outputs=[gallery_output, status_gallery_text]
    )

    # Ground truth comparison
    submit_gt_btn.click(
        fn=process_with_ground_truth,
        inputs=[file_input_gt, gt_mask_input, text_input_gt,
                modality_dropdown_gt, window_dropdown_gt],
        outputs=[image_output_gt, comparison_output, dice_score_text,
                 iou_score_text, status_gt_text]
    )

    # Enhanced single image processing
    def process_enhanced_wrapper(image_file, prompt_text, modality, window_type, brightness,
                                 contrast, colormap, transparency, progress=gr.Progress()):
        """Wrapper to return both image and download file.

        Delegates to process_with_progress and duplicates the result path into
        a fourth output so the same file feeds both the preview image and the
        download component.
        """
        result, status, metrics = process_with_progress(
            image_file, prompt_text, modality, window_type, brightness,
            contrast, colormap, transparency, progress
        )
        download_file = result if result else None
        return result, status, metrics, download_file

    submit_enh_btn.click(
        fn=process_enhanced_wrapper,
        inputs=[
            file_input_enh, text_input_enh, modality_enh, window_enh,
            brightness_slider, contrast_slider, colormap_dropdown, transparency_slider
        ],
        outputs=[image_output_enh, progress_text, metrics_text, download_output]
    )

    # Enhanced batch processing
    submit_enh_batch_btn.click(
        fn=process_batch_enhanced,
        inputs=[
            files_input_enh_batch, text_input_enh_batch, modality_enh_batch,
            window_enh_batch, brightness_slider_batch, contrast_slider_batch,
            colormap_dropdown_batch, transparency_slider_batch
        ],
        outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
    )

    # Point prompt processing
    submit_point_btn.click(
        fn=process_with_point_prompt,
        inputs=[file_input_point, point_x, point_y, modality_point,
                window_point, colormap_point, transparency_point],
        outputs=[output_point, status_point]
    )

    # Box prompt processing
    submit_box_btn.click(
        fn=process_with_box_prompt,
        inputs=[file_input_box, box_x1, box_y1, box_x2, box_y2,
                modality_box, window_box, colormap_box, transparency_box],
        outputs=[output_box, status_box]
    )

    # ROI Statistics processing (also stores the mask for later export below)
    submit_stats_btn.click(
        fn=process_and_store_mask,
        inputs=[file_input_stats, text_input_stats, modality_stats, window_stats],
        outputs=[output_stats, status_stats, roi_stats_text]
    )

    # NIFTI Export — operates on the last stored mask, hence no inputs
    export_nifti_btn.click(
        fn=export_last_mask_nifti,
        inputs=[],
        outputs=[nifti_download, status_stats]
    )

    # Save Annotation — likewise works on the last stored result
    save_annotation_btn.click(
        fn=save_last_annotation,
        inputs=[],
        outputs=[annotation_download, status_stats]
    )

    # Load Annotation
    load_annotation_btn.click(
        fn=visualize_loaded_annotation,
        inputs=[original_image_upload, annotation_upload],
        outputs=[loaded_annotation_output, loaded_annotation_info]
    )

    # Multi-Mask processing
    submit_multi_btn.click(
        fn=process_multi_mask,
        inputs=[file_input_multi, text_input_multi, modality_multi,
                window_multi, num_masks_slider],
        outputs=[gallery_multi, status_multi, mask_info_multi]
    )

    # Auto-play functionality for slice viewer
    def auto_play_slices(files, selected_subject, prompt, mod, window):
        """Auto-play through slices with a short delay.

        Generator handler: yields (image, info text, slider index) tuples so
        Gradio streams each slice to the UI. Relies on results previously
        cached by the slice-viewer processing step.
        """
        if not files:
            yield None, "No slices loaded", 0
            return
        subject_groups = group_images_by_subject(files)
        # Dropdown entries look like "subject (...)"; recover the bare id.
        if selected_subject:
            subject_id = selected_subject.split(" (")[0]
        else:
            subject_id = list(subject_groups.keys())[0] if subject_groups else None
        if not subject_id or subject_id not in subject_groups:
            yield None, "No slices loaded", 0
            return
        subject_files = subject_groups[subject_id]['files']
        # NOTE(review): cache key omits `window` — presumably the processing
        # step keys its cache identically; verify against that handler.
        cache_key = f"{subject_id}_{len(subject_files)}_{prompt}_{mod}"
        if cache_key not in processed_results_cache:
            yield None, "Please process slices first", 0
            return
        results = processed_results_cache[cache_key]
        for idx in range(len(results)):
            slice_info = f"Slice {idx + 1}/{len(results)} ({subject_id}) - Auto-playing..."
            yield results[idx], slice_info, idx
            time.sleep(0.5)  # 500ms delay between slices

    auto_play_btn.click(
        fn=auto_play_slices,
        inputs=[files_input, subject_dropdown, text_input_batch,
                modality_dropdown_batch, window_dropdown_batch],
        outputs=[current_slice_output, slice_info_text, slice_slider]
    )

    # SAM-Medical-Imaging Inspired Features

    # Automatic Mask Generator
    submit_amg_btn.click(
        fn=automatic_mask_generator,
        inputs=[file_input_amg, modality_amg, window_amg,
                points_per_side, min_mask_area, colormap_amg],
        outputs=[output_amg, status_amg, info_amg]
    )

    # Advanced Transforms
    submit_transform_btn.click(
        fn=process_with_advanced_transforms,
        inputs=[file_input_transform, text_input_transform, modality_transform,
                window_transform, target_size_slider, apply_clahe_checkbox,
                clahe_clip_slider, colormap_transform, transparency_transform],
        outputs=[output_transform, status_transform]
    )

    # Edge-Based Segmentation
    submit_edge_btn.click(
        fn=edge_based_segmentation,
        inputs=[file_input_edge, modality_edge, window_edge,
                edge_threshold_slider, dilation_size_slider,
                colormap_edge, transparency_edge],
        outputs=[output_edge, status_edge]
    )


if __name__ == "__main__":
    # Verify model is loaded before launching; the app still starts with
    # degraded functionality when loading failed at import time.
    if model is None or processor is None:
        print("⚠️ WARNING: SAM 3 model failed to load!")
        print("⚠️ The app will start but segmentation features will not work.")
        print("⚠️ Please check:")
        print(" 1. HF_TOKEN environment variable is set correctly")
        print(" 2. transformers>=4.45.0 is installed")
        print(" 3. Sufficient memory/GPU available")
    else:
        print("✅ SAM 3 model ready - app starting...")
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)