""" NeuroSAM 3: Medical Image Segmentation App A Gradio app for segmenting medical images (CT/MRI) using SAM 3 """ from typing import Optional, Tuple, List, Dict, Any, Union import os import tempfile import zipfile import io import json import time from datetime import datetime import gradio as gr import spaces import torch import pydicom import numpy as np from PIL import Image, ImageEnhance, ImageDraw import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from scipy import ndimage from huggingface_hub import login # Import custom modules from config import ( DEMO_DICOM_PATH, DEFAULT_THRESHOLD, DEFAULT_MASK_THRESHOLD, DEFAULT_COLORMAP, DEFAULT_TRANSPARENCY, DEFAULT_BRIGHTNESS, DEFAULT_CONTRAST, OUTPUT_DPI, NIFTI_DEFAULT_NAME, ) from logger_config import logger from models import initialize_model, is_model_loaded, get_model, get_processor, run_sam3_inference from dicom_utils import ( is_dicom_file, process_dicom_to_pil, process_standard_image_to_pil, ) from validators import ( validate_image_file, validate_prompt_text, validate_modality, validate_threshold, validate_mask_threshold, validate_coordinates, validate_bounding_box, validate_num_masks, validate_transparency, validate_brightness_contrast, ValidationError, ) from cache_manager import processed_results_cache from utils import ( extract_subject_id, group_images_by_subject, combine_masks, create_output_image, create_demo_dicom_file, ) from segmentation import ( compare_with_ground_truth, calculate_roi_statistics, format_roi_statistics, generate_grid_points, calculate_dice_score, calculate_iou_score, ) # Try to import nibabel for NIFTI support (optional) try: import nibabel as nib NIBABEL_AVAILABLE = True except ImportError: NIBABEL_AVAILABLE = False logger.warning("nibabel not available - NIFTI export disabled") # Initialize Hugging Face login from config import HF_TOKEN if HF_TOKEN: try: login(token=HF_TOKEN, add_to_git_credential=False) logger.info("Logged in to Hugging Face Hub") except 
Exception as e: logger.warning(f"Could not login to HF Hub (non-critical): {e}") else: logger.warning("HF_TOKEN not set - some features may not work") # Initialize SAM 3 Model logger.info("Loading SAM 3 Model...") model_loaded = initialize_model() if not model_loaded: logger.warning("SAM 3 model failed to load - segmentation features will be disabled") # Get model and processor references model = get_model() processor = get_processor() # Create Sample DICOM File for Demo demo_file_available = create_demo_dicom_file(DEMO_DICOM_PATH) # compare_with_ground_truth is now imported from segmentation module def process_medical_image( image_file: Optional[str], prompt_text: Optional[str], modality: str, window_type: str, return_mask: bool = False ) -> Optional[Union[str, Tuple[str, Optional[np.ndarray]]]]: """ Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3. Args: image_file: Path to image file prompt_text: Text prompt for segmentation modality: CT or MRI window_type: Windowing strategy return_mask: If True, also return the binary mask array Returns: Path to output image, and optionally the mask array """ if not is_model_loaded(): logger.error("Model not loaded") return None if image_file is None: return None # Validate inputs is_valid, error = validate_image_file(image_file) if not is_valid: logger.error(f"Invalid image file: {error}") return None is_valid, error = validate_modality(modality) if not is_valid: logger.error(f"Invalid modality: {error}") return None is_valid, error, prompt_text = validate_prompt_text(prompt_text) if not is_valid: logger.error(f"Invalid prompt: {error}") return None try: file_path = str(image_file) # Process image based on type if is_dicom_file(file_path): pil_image = process_dicom_to_pil(file_path, modality, window_type) else: pil_image = process_standard_image_to_pil(file_path, modality, window_type) # Run SAM 3 Inference results = run_sam3_inference( pil_image, prompt_text, threshold=DEFAULT_THRESHOLD, 
mask_threshold=DEFAULT_MASK_THRESHOLD ) if results is None: logger.warning("SAM 3 inference returned None") return None # Extract and combine masks final_mask = None if 'masks' in results and results['masks'] is not None: masks = results['masks'] if len(masks) > 0: final_mask = combine_masks(masks) if final_mask is None: logger.warning("No valid masks found after combining") else: logger.warning("No masks in results") else: logger.warning("No masks in results") # Create output visualization output_path = create_output_image( pil_image, final_mask, prompt_text, colormap=DEFAULT_COLORMAP, transparency=DEFAULT_TRANSPARENCY ) if return_mask: return output_path, final_mask return output_path except pydicom.errors.InvalidDicomError as e: logger.error(f"Invalid DICOM file format: {e}", exc_info=True) return None except Exception as e: logger.error(f"Error processing image: {e}", exc_info=True) return None def process_medical_image_enhanced(image_file, prompt_text, modality, window_type, brightness=1.0, contrast=1.0, colormap='spring', transparency=0.5, return_mask=False): """Enhanced version of process_medical_image with image adjustments and visualization options. 
Args: image_file: Path to image file prompt_text: Text prompt for segmentation modality: CT or MRI window_type: Windowing strategy brightness: Brightness multiplier (0.5-2.0) contrast: Contrast multiplier (0.5-2.0) colormap: Matplotlib colormap name transparency: Mask overlay transparency (0.0-1.0) return_mask: If True, also return the binary mask array Returns: Path to output image, and optionally the mask array """ if not is_model_loaded(): logger.error("Model not loaded") return None if image_file is None: return None # Validate and sanitize prompt is_valid, error, prompt_text = validate_prompt_text(prompt_text) if not is_valid: logger.error(f"Invalid prompt: {error}") return None try: file_path = str(image_file) # Validate file is_valid, error = validate_image_file(file_path) if not is_valid: logger.error(f"Invalid image file: {error}") return None # Detect file type file_ext = os.path.splitext(file_path)[1].lower() is_dicom = file_ext == '.dcm' if is_dicom: # Process DICOM file ds = pydicom.dcmread(file_path) if not hasattr(ds, 'pixel_array'): logger.error("DICOM file does not contain pixel data") return None raw = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_hu = raw * slope + intercept # Apply Windowing if modality == "CT": if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI img_min = np.percentile(img_hu, 1) img_max = np.percentile(img_hu, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_hu) img_max = np.max(img_hu) img_range = img_max - img_min if img_range <= 0: return None img_windowed = (img_hu - img_min) / img_range img_windowed = np.clip(img_windowed, 0, 1) img_uint8 = (img_windowed * 255).astype(np.uint8) if len(img_uint8.shape) == 2: pil_image = 
Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.fromarray(img_uint8) else: # Process standard image file (PNG, JPG, etc.) pil_image = Image.open(file_path) # Convert to RGB if needed if pil_image.mode != 'RGB': pil_image = pil_image.convert('RGB') # Convert to numpy for normalization img_array = np.array(pil_image) # Handle grayscale images if len(img_array.shape) == 2: img_array = np.stack([img_array] * 3, axis=-1) # Normalize image (percentile-based for MRI-like processing) img_float = img_array.astype(np.float32) if modality == "CT": # For CT-like processing, use windowing if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: # MRI - use percentile normalization img_min = np.percentile(img_float, 1) img_max = np.percentile(img_float, 99) img_range = img_max - img_min if img_range <= 0: img_min = np.min(img_float) img_max = np.max(img_float) img_range = img_max - img_min if img_range <= 0: return None img_normalized = (img_float - img_min) / img_range img_normalized = np.clip(img_normalized, 0, 1) img_uint8 = (img_normalized * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8.astype(np.uint8)) # Apply brightness and contrast adjustments enhancer = ImageEnhance.Brightness(pil_image) pil_image = enhancer.enhance(brightness) enhancer = ImageEnhance.Contrast(pil_image) pil_image = enhancer.enhance(contrast) # Run SAM 3 Inference - using helper function matching official implementation # Lower thresholds for medical images to ensure detections are not filtered out results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) if results is None: return None # Draw Masks on Image with enhanced visualization - matching official implementation format plt.figure(figsize=(10, 10)) plt.imshow(pil_image) final_mask = None if 'masks' in results and 
results['masks'] is not None: masks = results['masks'] # List of mask tensors from post_process_instance_segmentation scores = results.get('scores', []) if len(masks) > 0: # Combine all masks into one mask_arrays = [] for mask in masks: if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) mask_arrays.append(mask_np) # Combine all masks if len(mask_arrays) > 0: final_mask = np.any(mask_arrays, axis=0) plt.imshow(final_mask, alpha=transparency, cmap=colormap) else: logger.warning("No valid masks found") else: logger.warning("No masks in results") else: logger.warning("No masks in results") plt.axis('off') plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() if return_mask: return output_path, final_mask return output_path except pydicom.errors.InvalidDicomError as e: logger.error(f"Invalid DICOM file format: {e}", exc_info=True) return None except Exception as e: logger.error(f"Error processing image: {e}", exc_info=True) import traceback traceback.print_exc() return None def process_with_progress( image_file: Optional[str], prompt_text: Optional[str], modality: str, window_type: str, brightness: float = DEFAULT_BRIGHTNESS, contrast: float = DEFAULT_CONTRAST, colormap: str = DEFAULT_COLORMAP, transparency: float = DEFAULT_TRANSPARENCY, progress: Any = gr.Progress() ) -> Tuple[Optional[str], str, str]: """Process with progress indicator.""" if not is_model_loaded(): return None, "❌ Error: Model not loaded.", "" if image_file is None: return None, "⚠️ Please upload a medical image file.", "" progress(0, desc="Starting...") progress(0.1, desc="Reading image...") result = process_medical_image_enhanced( image_file, prompt_text, modality, window_type, brightness, contrast, colormap, transparency ) progress(0.8, 
desc="Processing segmentation...") if result is None: progress(1.0, desc="Failed!") return None, "❌ Processing failed. Check console for error details.", "" progress(1.0, desc="Complete!") # Calculate metrics metrics = f"Prompt: {prompt_text}\nModality: {modality}\nProcessed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" return result, "✅ Segmentation complete!", metrics def create_downloadable_result(output_path, prompt_text): """Create a downloadable file with metadata.""" return output_path # Gradio File component handles downloads automatically def process_batch_enhanced(image_files, prompt_text, modality, window_type, brightness=1.0, contrast=1.0, colormap='spring', transparency=0.5, progress=gr.Progress()): """Process multiple images with enhanced features and create ZIP download.""" if not is_model_loaded(): return [], None, "❌ Error: Model not loaded." if not image_files: return [], None, "⚠️ Please upload medical image files." # Handle single file or list of files if isinstance(image_files, str): image_files = [image_files] results = [] total = len(image_files) for idx, image_file in enumerate(image_files): if image_file is None: continue progress((idx + 1) / total, desc=f"Processing image {idx + 1}/{total}...") result = process_medical_image_enhanced( image_file, prompt_text, modality, window_type, brightness, contrast, colormap, transparency ) if result: results.append(result) if not results: return [], None, "❌ No images were processed successfully." 
# Create ZIP file with all results zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file: for idx, result_path in enumerate(results): if os.path.exists(result_path): filename = f"segmentation_{idx + 1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" zip_file.write(result_path, filename) zip_buffer.seek(0) zip_path = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') zip_path.write(zip_buffer.read()) zip_path.close() status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download." return results, zip_path.name, status # ============================================================================ # ENHANCED FEATURES - Auto-play, Point/Box Prompts, ROI Stats, NIFTI Export # ============================================================================ # Global state for auto-play auto_play_state = {"running": False, "current_idx": 0} # calculate_roi_statistics is now imported from segmentation module def _calculate_roi_statistics_legacy(image_file, mask, modality): """Calculate ROI statistics from the segmented region. 
Returns: dict: Statistics including area, mean intensity, std, min, max, centroid """ if mask is None or not isinstance(mask, np.ndarray): return { "error": "No valid mask available", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } try: # Load original image for intensity statistics file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept else: img = Image.open(file_path) if img.mode == 'RGB': img = img.convert('L') # Convert to grayscale for intensity stats img_array = np.array(img).astype(np.float32) # Resize mask if needed if mask.shape != img_array.shape: from scipy.ndimage import zoom zoom_factors = (img_array.shape[0] / mask.shape[0], img_array.shape[1] / mask.shape[1]) mask = zoom(mask.astype(float), zoom_factors, order=0) > 0.5 # Calculate statistics mask_bool = mask.astype(bool) total_pixels = mask.size roi_pixels = np.sum(mask_bool) if roi_pixels == 0: return { "error": "No pixels in ROI", "area_pixels": 0, "area_percentage": 0, "mean_intensity": 0, "std_intensity": 0, "min_intensity": 0, "max_intensity": 0, "centroid": (0, 0), "bounding_box": (0, 0, 0, 0) } roi_intensities = img_array[mask_bool] # Calculate centroid labeled_mask, num_features = ndimage.label(mask_bool) centroid = ndimage.center_of_mass(mask_bool) # Calculate bounding box rows = np.any(mask_bool, axis=1) cols = np.any(mask_bool, axis=0) rmin, rmax = np.where(rows)[0][[0, -1]] cmin, cmax = np.where(cols)[0][[0, -1]] stats = { "area_pixels": int(roi_pixels), "area_percentage": float(roi_pixels / total_pixels * 100), "mean_intensity": float(np.mean(roi_intensities)), 
"std_intensity": float(np.std(roi_intensities)), "min_intensity": float(np.min(roi_intensities)), "max_intensity": float(np.max(roi_intensities)), "centroid": (float(centroid[1]), float(centroid[0])), # (x, y) "bounding_box": (int(cmin), int(rmin), int(cmax), int(rmax)), # (x1, y1, x2, y2) "num_components": num_features } # Add HU statistics for CT if modality == "CT": stats["mean_hu"] = stats["mean_intensity"] stats["std_hu"] = stats["std_intensity"] return stats except Exception as e: logger.error(f"Error calculating ROI statistics: {e}") return {"error": str(e)} # format_roi_statistics is now imported from segmentation module def process_with_roi_stats(image_file, prompt_text, modality, window_type): """Process image and return both segmentation and ROI statistics.""" if not is_model_loaded(): return None, "❌ Error: Model not loaded.", "" if image_file is None: return None, "⚠️ Please upload a medical image file.", "" result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True) if result is None: return None, "❌ Processing failed.", "" # Calculate ROI statistics stats = calculate_roi_statistics(image_file, mask, modality) stats_text = format_roi_statistics(stats) return result, "✅ Segmentation complete!", stats_text def process_with_point_prompt(image_file, point_x, point_y, modality, window_type, colormap='spring', transparency=0.5): """Process image with a point prompt for segmentation. Note: This simulates point-based prompting by using the point location as a seed for region-based segmentation. """ if not is_model_loaded(): return None, "❌ Error: Model not loaded." if image_file is None: return None, "⚠️ Please upload a medical image file." 
try: # Load image file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept # Normalize img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.open(file_path).convert('RGB') img_array = np.array(pil_image) h, w = img_array.shape[:2] # Clamp point coordinates point_x = max(0, min(int(point_x), w - 1)) point_y = max(0, min(int(point_y), h - 1)) # Create a prompt based on the point location prompt_text = f"segment region at point" # Process with SAM 3 - using helper function # Lower thresholds for medical images to ensure detections are not filtered out results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) final_mask = None if results and 'masks' in results and results['masks'] is not None: masks = results['masks'] # Select mask containing the point for mask in masks: if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) # Resize to image size mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127 if mask_resized[point_y, point_x]: final_mask = mask_resized break # If no mask contains the point, use first mask if final_mask is None and len(masks) > 0: mask = masks[0] if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) final_mask = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127 # Draw result with point marker plt.figure(figsize=(10, 10)) plt.imshow(pil_image) if 
final_mask is not None: plt.imshow(final_mask, alpha=transparency, cmap=colormap) # Draw point marker plt.scatter([point_x], [point_y], c='red', s=200, marker='+', linewidths=3) plt.scatter([point_x], [point_y], c='red', s=100, marker='o', facecolors='none', linewidths=2) plt.axis('off') plt.title(f"Point Prompt Segmentation at ({point_x}, {point_y})", fontsize=12) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() return output_path, f"✅ Point-based segmentation at ({point_x}, {point_y})" except Exception as e: logger.error(f"Error in point prompt processing: {e}") import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, colormap='spring', transparency=0.5): """Process image with a bounding box prompt for segmentation.""" if not is_model_loaded(): return None, "❌ Error: Model not loaded." if image_file is None: return None, "⚠️ Please upload a medical image file." 
try: # Load image file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.open(file_path).convert('RGB') img_array = np.array(pil_image) h, w = img_array.shape[:2] # Ensure box coordinates are valid x1, x2 = min(x1, x2), max(x1, x2) y1, y2 = min(y1, y2), max(y1, y2) x1, y1 = max(0, int(x1)), max(0, int(y1)) x2, y2 = min(w, int(x2)), min(h, int(y2)) prompt_text = "segment region in bounding box" # Process with SAM 3 - using helper function # Lower thresholds for medical images to ensure detections are not filtered out results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) final_mask = None if results and 'masks' in results and results['masks'] is not None: masks = results['masks'] # Combine all masks mask_arrays = [] for mask in masks: if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) # Resize to image size mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127 mask_arrays.append(mask_resized) if len(mask_arrays) > 0: combined = np.any(mask_arrays, axis=0) # Create box mask and intersect box_mask = np.zeros((h, w), dtype=bool) box_mask[y1:y2, x1:x2] = True final_mask = combined & box_mask # Draw result with box plt.figure(figsize=(10, 10)) plt.imshow(pil_image) if final_mask is not None: plt.imshow(final_mask, alpha=transparency, cmap=colormap) # Draw bounding box rect = 
Rectangle((x1, y1), x2-x1, y2-y1, linewidth=3, edgecolor='red', facecolor='none') plt.gca().add_patch(rect) plt.axis('off') plt.title(f"Box Prompt Segmentation [{x1}, {y1}, {x2}, {y2}]", fontsize=12) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() return output_path, f"✅ Box-based segmentation at [{x1}, {y1}, {x2}, {y2}]" except Exception as e: logger.error(f"Error in box prompt processing: {e}") import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks=3): """Process image and return multiple mask candidates with confidence scores.""" if not is_model_loaded(): return [], "❌ Error: Model not loaded.", "" if image_file is None: return [], "⚠️ Please upload a medical image file.", "" try: file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.open(file_path).convert('RGB') if not prompt_text or not prompt_text.strip(): prompt_text = "brain" # Process with SAM 3 - using helper function # Lower thresholds for medical images to ensure detections are not filtered out sam_results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0) results = [] mask_info = [] if sam_results and 'masks' in sam_results and sam_results['masks'] is not 
None: masks = sam_results['masks'] # List of mask tensors scores = sam_results.get('scores', []) # List of scores num_available = len(masks) num_to_show = min(num_masks, num_available) colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma'] for i in range(num_to_show): mask = masks[i] if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy() else: mask_np = np.array(mask) # Convert to boolean if mask_np.dtype != bool: mask_np = mask_np > 0.5 score = scores[i].item() if i < len(scores) and isinstance(scores[i], torch.Tensor) else (scores[i] if i < len(scores) else 0.5) # Create visualization plt.figure(figsize=(8, 8)) plt.imshow(pil_image) plt.imshow(mask_np, alpha=0.5, cmap=colormaps[i % len(colormaps)]) plt.axis('off') plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() results.append(output_path) mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask_np):,} pixels") status = f"✅ Generated {len(results)} mask candidate(s)" info = "\n".join(mask_info) if mask_info else "No mask information available" return results, status, info except Exception as e: logger.error(f"Error in multi-mask processing: {e}") import traceback traceback.print_exc() return [], f"❌ Error: {str(e)}", "" def export_to_nifti(image_file, mask, output_name="segmentation"): """Export segmentation mask to NIFTI format. 
Returns: str: Path to the exported NIFTI file, or None if export failed """ if not NIBABEL_AVAILABLE: return None, "⚠️ NIFTI export not available - nibabel not installed" if mask is None or not isinstance(mask, np.ndarray): return None, "⚠️ No valid mask to export" try: # Convert mask to appropriate format mask_data = mask.astype(np.float32) # Create NIFTI image # Use identity affine (1mm isotropic) affine = np.eye(4) # Try to get spacing from DICOM if available if image_file: file_path = image_file if isinstance(image_file, str) else str(image_file) if file_path.lower().endswith('.dcm'): try: ds = pydicom.dcmread(file_path, stop_before_pixels=True) pixel_spacing = getattr(ds, 'PixelSpacing', [1.0, 1.0]) slice_thickness = getattr(ds, 'SliceThickness', 1.0) affine[0, 0] = float(pixel_spacing[0]) affine[1, 1] = float(pixel_spacing[1]) affine[2, 2] = float(slice_thickness) except Exception as e: logger.debug(f"Could not extract spacing from DICOM: {e}") pass nifti_img = nib.Nifti1Image(mask_data, affine) # Save to temp file output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') output_path = output_file.name output_file.close() nib.save(nifti_img, output_path) return output_path, f"✅ Exported to NIFTI: {output_path}" except Exception as e: logger.error(f"Error exporting to NIFTI: {e}") return None, f"❌ Export failed: {str(e)}" def save_annotation(image_file, mask, prompt_text, modality, stats=None): """Save annotation to a JSON file for later loading.""" if mask is None: return None, "⚠️ No annotation to save" try: annotation = { "timestamp": datetime.now().isoformat(), "image_file": os.path.basename(image_file) if image_file else "unknown", "prompt": prompt_text, "modality": modality, "mask_shape": list(mask.shape), "mask_sum": int(np.sum(mask)), "mask_base64": None, # We'll store as binary in a separate file "statistics": stats if stats else {} } # Save mask as numpy file mask_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npz') 
mask_path = mask_file.name mask_file.close() np.savez_compressed(mask_path, mask=mask) # Save annotation JSON json_file = tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') json_path = json_file.name annotation["mask_file"] = mask_path json.dump(annotation, json_file, indent=2) json_file.close() # Create ZIP with both files zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: zf.write(json_path, 'annotation.json') zf.write(mask_path, 'mask.npz') zip_buffer.seek(0) zip_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip') zip_path = zip_file.name zip_file.write(zip_buffer.read()) zip_file.close() return zip_path, f"✅ Annotation saved: {os.path.basename(zip_path)}" except Exception as e: logger.error(f"Error saving annotation: {e}") return None, f"❌ Save failed: {str(e)}" def load_annotation(annotation_file): """Load a previously saved annotation.""" if annotation_file is None: return None, None, "⚠️ No file selected" try: file_path = annotation_file if isinstance(annotation_file, str) else str(annotation_file) if file_path.endswith('.zip'): # Extract ZIP with zipfile.ZipFile(file_path, 'r') as zf: # Read annotation JSON with zf.open('annotation.json') as f: annotation = json.load(f) # Extract mask file mask_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.npz') mask_temp.write(zf.read('mask.npz')) mask_temp.close() mask_data = np.load(mask_temp.name) mask = mask_data['mask'] info = f"📋 **Loaded Annotation**\n" info += f"Image: {annotation.get('image_file', 'unknown')}\n" info += f"Prompt: {annotation.get('prompt', 'N/A')}\n" info += f"Modality: {annotation.get('modality', 'N/A')}\n" info += f"Saved: {annotation.get('timestamp', 'N/A')}\n" info += f"Mask size: {annotation.get('mask_sum', 0):,} pixels" return mask, annotation, info else: return None, None, "⚠️ Invalid file format. Please upload a .zip annotation file." 
except Exception as e: logger.error(f"Error loading annotation: {e}") return None, None, f"❌ Load failed: {str(e)}" def visualize_loaded_annotation(image_file, annotation_file, colormap='spring', transparency=0.5): """Visualize a loaded annotation on the original image.""" mask, annotation, info = load_annotation(annotation_file) if mask is None: return None, info if image_file is None: return None, "⚠️ Please upload the original image to visualize" try: file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) pil_image = Image.fromarray(img_uint8).convert('RGB') else: pil_image = Image.open(file_path).convert('RGB') # Resize mask if needed w, h = pil_image.size if mask.shape != (h, w): mask = np.array(Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))) > 127 # Visualize plt.figure(figsize=(10, 10)) plt.imshow(pil_image) plt.imshow(mask, alpha=transparency, cmap=colormap) plt.axis('off') plt.title(f"Loaded Annotation: {annotation.get('prompt', 'N/A')}", fontsize=12) output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100) plt.close() return output_path, info except Exception as e: logger.error(f"Error visualizing annotation: {e}") return None, f"❌ Visualization failed: {str(e)}" # Store last mask for export/save operations last_processed_mask = {"mask": None, "image_file": None, "prompt": None, "modality": None} def 
process_and_store_mask(image_file, prompt_text, modality, window_type): """Process image and store mask for export/save operations.""" result, mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True) if result and mask is not None: last_processed_mask["mask"] = mask last_processed_mask["image_file"] = image_file last_processed_mask["prompt"] = prompt_text last_processed_mask["modality"] = modality # Calculate stats (using imported function from segmentation module) stats = calculate_roi_statistics(image_file, mask, modality) stats_text = format_roi_statistics(stats) return result, "✅ Segmentation complete! Ready for export.", stats_text else: return result, "❌ Processing failed.", "" def export_last_mask_nifti(): """Export the last processed mask to NIFTI.""" if last_processed_mask["mask"] is None: return None, "⚠️ No mask to export. Process an image first." return export_to_nifti( last_processed_mask["image_file"], last_processed_mask["mask"] ) # ============================================================================ # SAM-MEDICAL-IMAGING UTILITIES (Inspired by amine0110/SAM-Medical-Imaging) # Automatic Mask Generator, Advanced Transforms, Grid-based Segmentation # ============================================================================ class ResizeLongestSide: """ Resizes images to the longest side target length while maintaining aspect ratio. 
    Inspired by SAM-Medical-Imaging transforms.py
    """

    def __init__(self, target_length: int = 1024):
        # Longest-side length (pixels) that apply_image scales to.
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """Resize image maintaining aspect ratio."""
        h, w = image.shape[:2]
        scale = self.target_length / max(h, w)
        new_h, new_w = int(h * scale), int(w * scale)
        pil_image = Image.fromarray(image)
        pil_image = pil_image.resize((new_w, new_h), Image.BILINEAR)
        return np.array(pil_image)

    def apply_coords(self, coords: np.ndarray, original_size: tuple) -> np.ndarray:
        """Resize coordinates to match resized image."""
        # Same scale factor as apply_image, derived from the pre-resize size.
        old_h, old_w = original_size
        scale = self.target_length / max(old_h, old_w)
        return coords * scale

    def apply_boxes(self, boxes: np.ndarray, original_size: tuple) -> np.ndarray:
        """Resize bounding boxes to match resized image."""
        # Assumes boxes are (..., 4) corner pairs so both halves scale like
        # point coordinates — TODO confirm against callers.
        boxes = boxes.copy()
        boxes[..., :2] = self.apply_coords(boxes[..., :2], original_size)
        boxes[..., 2:] = self.apply_coords(boxes[..., 2:], original_size)
        return boxes


# generate_grid_points is now imported from segmentation module
def automatic_mask_generator(image_file, modality, window_type, points_per_side=16, min_mask_area=100, colormap='tab20', progress=gr.Progress()):
    """
    Automatic Mask Generator (AMG) - Generate masks for entire image without prompts.
    Uses a grid of points to query the model and generates masks automatically.
    Inspired by SAM-Medical-Imaging's amg.py

    Args:
        image_file: Path to the input image (DICOM .dcm or standard format).
        modality: "CT" or "MRI"; selects fixed-window vs percentile windowing.
        window_type: CT windowing preset name (ignored when modality != "CT").
        points_per_side: Grid density for the sampled point grid.
        min_mask_area: Minimum mask area in pixels; smaller masks are dropped.
        colormap: Matplotlib colormap used to color detected regions.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of (output PNG path or None, status message, info text).
    """
    if not is_model_loaded():
        return None, "❌ Error: Model not loaded.", ""
    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""
    try:
        progress(0.1, desc="Loading image...")
        # Load and preprocess image
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
            # Apply windowing
            if modality == "CT":
                # Fixed HU level/width presets per window type.
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                # Non-CT: robust percentile window.
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')
        img_array = np.array(pil_image)
        h, w = img_array.shape[:2]
        progress(0.2, desc="Generating grid points...")
        # Generate grid of points
        # NOTE(review): grid_points are generated but never passed to the model
        # below — segmentation actually runs on the text prompts only.
        grid_points = generate_grid_points((h, w), points_per_side)
        total_points = len(grid_points)
        # Collect all masks
        all_masks = []
        all_scores = []
        progress(0.3, desc=f"Processing {total_points} points...")
        # Use different prompts to generate diverse masks
        prompts = ["anatomical structure", "region", "tissue", "organ", "segment"]
        for prompt_idx, prompt in enumerate(prompts):
            progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")
            try:
                # Process with SAM 3 - using helper function
                # Lower thresholds for medical images to ensure detections are not filtered out
                sam_results = run_sam3_inference(pil_image, prompt, threshold=0.1, mask_threshold=0.0)
                if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
                    masks = sam_results['masks']  # List of mask tensors
                    for mask in masks:
                        if isinstance(mask, torch.Tensor):
                            mask_np = mask.cpu().numpy()
                        else:
                            mask_np = np.array(mask)
                        # Convert to boolean
                        if mask_np.dtype != bool:
                            mask_np = mask_np > 0.5
                        # Filter by minimum area
                        mask_area = np.sum(mask_np)
                        if mask_area >= min_mask_area:
                            # Resize mask to image size
                            mask_resized = np.array(
                                Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))
                            ) > 127
                            all_masks.append(mask_resized)
                            all_scores.append(mask_area)
            except Exception as e:
                # Best-effort per prompt: log and keep going with the rest.
                logger.error(f"Error with prompt '{prompt}': {e}")
                continue
        progress(0.85, desc="Combining masks...")
        if not all_masks:
            return None, "⚠️ No masks generated. Try different parameters.", ""
        # Remove duplicate/overlapping masks using Non-Maximum Suppression
        unique_masks = []
        for mask in all_masks:
            is_duplicate = False
            for existing in unique_masks:
                # Check IoU
                intersection = np.logical_and(mask, existing).sum()
                union = np.logical_or(mask, existing).sum()
                iou = intersection / union if union > 0 else 0
                if iou > 0.8:  # High overlap threshold
                    is_duplicate = True
                    break
            if not is_duplicate:
                unique_masks.append(mask)
        progress(0.9, desc="Creating visualization...")
        # Create visualization with all masks
        plt.figure(figsize=(12, 12))
        plt.imshow(pil_image)
        # Create colored overlay for all masks
        # NOTE(review): plt.cm.get_cmap is deprecated in newer matplotlib
        # (removed in 3.9) — consider plt.get_cmap / matplotlib.colormaps.
        cmap = plt.cm.get_cmap(colormap)
        for idx, mask in enumerate(unique_masks):
            color = cmap(idx / max(len(unique_masks), 1))[:3]
            colored_mask = np.zeros((*mask.shape, 4))
            colored_mask[mask] = (*color, 0.4)
            plt.imshow(colored_mask)
        plt.axis('off')
        plt.title(f"Automatic Mask Generation: {len(unique_masks)} regions detected", fontsize=14)
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()
        progress(1.0, desc="Complete!")
        # Generate info text
        info_text = f"📊 **AMG Results**\n\n"
        info_text += f"**Total Regions:** {len(unique_masks)}\n"
        info_text += f"**Grid Points:** {points_per_side}x{points_per_side} = {points_per_side**2}\n"
        info_text += f"**Min Area Filter:** {min_mask_area} pixels\n\n"
        for idx, mask in enumerate(unique_masks[:10]):  # Show first 10
            area = np.sum(mask)
            percentage = area / (h * w) * 100
            info_text += f"Region {idx+1}: {area:,} pixels ({percentage:.2f}%)\n"
        if len(unique_masks) > 10:
            info_text += f"\n... and {len(unique_masks) - 10} more regions"
        return output_path, f"✅ AMG Complete! Found {len(unique_masks)} regions.", info_text
    except Exception as e:
        logger.error(f"Error in AMG: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", ""


def process_with_advanced_transforms(image_file, prompt_text, modality, window_type, target_size=1024, apply_clahe=False, clahe_clip=2.0, colormap='spring', transparency=0.5):
    """
    Process image with advanced transforms from SAM-Medical-Imaging.
    - ResizeLongestSide: Maintains aspect ratio
    - CLAHE: Contrast Limited Adaptive Histogram Equalization (optional)

    Returns:
        Tuple of (output PNG path or None, status message).
    """
    if not is_model_loaded():
        return None, "❌ Error: Model not loaded."
    if image_file is None:
        return None, "⚠️ Please upload a medical image file."
    try:
        # Load image
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext == '.dcm':
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
            # Apply windowing
            if modality == "CT":
                # Fixed HU level/width presets per window type.
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                # Non-CT: robust percentile window.
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
        else:
            # Standard image: force single-channel grayscale.
            img = Image.open(file_path)
            if img.mode != 'L':
                img = img.convert('L')
            img_uint8 = np.array(img)
        original_size = img_uint8.shape[:2]
        # Apply CLAHE if requested
        if apply_clahe:
            try:
                from scipy.ndimage import uniform_filter
                # Simple CLAHE-like enhancement using local contrast
                # (local z-score scaled by clahe_clip, recentred at 128 —
                # an approximation, not true tile-based CLAHE).
                local_mean = uniform_filter(img_uint8.astype(float), size=50)
                local_std = np.sqrt(uniform_filter(img_uint8.astype(float)**2, size=50) - local_mean**2 + 1e-8)
                # Enhance contrast
                enhanced = (img_uint8 - local_mean) / (local_std + 1e-8) * clahe_clip
                enhanced = np.clip(enhanced * 30 + 128, 0, 255).astype(np.uint8)
                img_uint8 = enhanced
            except Exception as e:
                # Best-effort: fall back to the unenhanced image.
                logger.warning(f"CLAHE enhancement failed: {e}")
        # Apply ResizeLongestSide transform
        transform = ResizeLongestSide(target_size)
        if len(img_uint8.shape) == 2:
            # Replicate grayscale to 3 channels for the model input.
            img_uint8_3ch = np.stack([img_uint8] * 3, axis=-1)
        else:
            img_uint8_3ch = img_uint8
        img_resized = transform.apply_image(img_uint8_3ch)
        pil_image = Image.fromarray(img_resized)
        if not prompt_text or not prompt_text.strip():
            # Default prompt when the textbox is empty.
            prompt_text = "brain"
        # Process with SAM 3 - using helper function
        # Lower thresholds for medical images to ensure detections are not filtered out
        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)
        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            masks = results['masks']
            # Combine all masks
            mask_arrays = []
            for mask in masks:
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)
                mask_arrays.append(mask_np)
            if len(mask_arrays) > 0:
                # Union of all predicted masks.
                final_mask = np.any(mask_arrays, axis=0)
        # Visualize
        plt.figure(figsize=(12, 6))
        # Original
        plt.subplot(1, 2, 1)
        plt.imshow(Image.fromarray(img_uint8_3ch[:, :, 0] if len(img_uint8_3ch.shape) == 3 else img_uint8_3ch), cmap='gray')
        plt.title(f"Original ({original_size[0]}x{original_size[1]})", fontsize=10)
        plt.axis('off')
        # Processed with mask
        plt.subplot(1, 2, 2)
        plt.imshow(pil_image)
        if final_mask is not None:
            # Resize mask to match processed image
            mask_h, mask_w = final_mask.shape
            img_h, img_w = img_resized.shape[:2]
            if (mask_h, mask_w) != (img_h, img_w):
                final_mask = np.array(
                    Image.fromarray(final_mask.astype(np.uint8) * 255).resize((img_w, img_h))
                ) > 127
            plt.imshow(final_mask, alpha=transparency, cmap=colormap)
        plt.title(f"Transformed ({target_size}px) + Segmentation", fontsize=10)
        plt.axis('off')
        plt.tight_layout()
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()
        plt.savefig(output_path, bbox_inches='tight', dpi=100)
        plt.close()
        status = f"✅ Processed with ResizeLongestSide({target_size})"
        if apply_clahe:
            status += f" + CLAHE(clip={clahe_clip})"
        return output_path, status
    except Exception as e:
        logger.error(f"Error in advanced transforms: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"


def edge_based_segmentation(image_file, modality, window_type, edge_threshold=50, dilation_size=3, colormap='spring', transparency=0.5):
    """
    Edge-based automatic segmentation using Sobel/Canny-like detection.
    Useful for finding boundaries in medical images.
""" if image_file is None: return None, "⚠️ Please upload a medical image file." try: # Load image file_path = image_file if isinstance(image_file, str) else str(image_file) file_ext = os.path.splitext(file_path)[1].lower() if file_ext == '.dcm': ds = pydicom.dcmread(file_path) img_array = ds.pixel_array.astype(np.float32) slope = getattr(ds, 'RescaleSlope', 1) intercept = getattr(ds, 'RescaleIntercept', 0) img_array = img_array * slope + intercept if modality == "CT": if window_type == "Brain (Grey Matter)": level, width = 40, 80 elif window_type == "Bone (Skull)": level, width = 500, 2000 else: level, width = 40, 400 img_min = level - (width / 2) img_max = level + (width / 2) else: img_min = np.percentile(img_array, 1) img_max = np.percentile(img_array, 99) img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1) img_uint8 = (img_norm * 255).astype(np.uint8) else: img = Image.open(file_path).convert('L') img_uint8 = np.array(img) # Compute Sobel edges from scipy.ndimage import sobel, binary_dilation, binary_fill_holes # Sobel edge detection dx = sobel(img_uint8.astype(float), axis=1) dy = sobel(img_uint8.astype(float), axis=0) edges = np.hypot(dx, dy) # Threshold edges edge_mask = edges > edge_threshold # Dilate edges to connect nearby components if dilation_size > 0: struct = np.ones((dilation_size, dilation_size)) edge_mask = binary_dilation(edge_mask, structure=struct) # Fill holes filled_mask = binary_fill_holes(edge_mask) # Label connected components labeled, num_features = ndimage.label(filled_mask) # Create RGB image for display pil_image = Image.fromarray(img_uint8).convert('RGB') # Visualize plt.figure(figsize=(15, 5)) plt.subplot(1, 3, 1) plt.imshow(img_uint8, cmap='gray') plt.title("Original Image", fontsize=10) plt.axis('off') plt.subplot(1, 3, 2) plt.imshow(edges, cmap='hot') plt.title(f"Edge Detection (threshold={edge_threshold})", fontsize=10) plt.axis('off') plt.subplot(1, 3, 3) plt.imshow(pil_image) plt.imshow(filled_mask, 
alpha=transparency, cmap=colormap) plt.title(f"Segmentation ({num_features} regions)", fontsize=10) plt.axis('off') plt.tight_layout() output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') output_path = output_file.name output_file.close() plt.savefig(output_path, bbox_inches='tight', dpi=100) plt.close() return output_path, f"✅ Edge-based segmentation complete! Found {num_features} regions." except Exception as e: logger.error(f"Error in edge segmentation: {e}") import traceback traceback.print_exc() return None, f"❌ Error: {str(e)}" def save_last_annotation(): """Save the last processed annotation.""" if last_processed_mask["mask"] is None: return None, "⚠️ No annotation to save. Process an image first." stats = calculate_roi_statistics( last_processed_mask["image_file"], last_processed_mask["mask"], last_processed_mask["modality"] ) return save_annotation( last_processed_mask["image_file"], last_processed_mask["mask"], last_processed_mask["prompt"], last_processed_mask["modality"], stats ) # Create Gradio Interface # Set demo_file_path after verifying file exists demo_file_path = DEMO_DICOM_PATH if demo_file_available and os.path.exists(DEMO_DICOM_PATH) else None def load_demo_file(): """Load the demo DICOM file.""" if demo_file_path and os.path.exists(demo_file_path): return demo_file_path, f"✅ Demo file loaded: {demo_file_path}\nReady to segment!" else: return None, "⚠️ Demo file not found. Please upload a medical image file (DICOM, PNG, or JPG)." def process_with_status(image_file, prompt_text, modality, window_type): """Wrapper function to update status during processing.""" if not is_model_loaded(): return None, "❌ Error: Model not loaded." if image_file is None: return None, "⚠️ Please upload a medical image file (DICOM, PNG, or JPG) or load the demo file." result = process_medical_image(image_file, prompt_text, modality, window_type) if result is None: return None, "❌ Processing failed. Check console for error details." 
    # Success branch of process_with_status (if/return above this chunk).
    else:
        return result, "✅ Segmentation complete!"


def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type):
    """Process image and compare with ground truth segmentation mask.

    Returns:
        Tuple of (result image path, comparison image path, dice score,
        IoU score, status message).
    """
    if not is_model_loaded():
        return None, None, 0.0, 0.0, "❌ Error: Model not loaded."
    if image_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file."
    if gt_mask_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file."
    # Process image and get mask
    result, pred_mask = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
    if result is None or pred_mask is None:
        return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details."
    # Compare with ground truth
    comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file)
    if comparison_path:
        status = f"✅ Segmentation complete!\nDice Score: {dice_score:.3f}\nIoU Score: {iou_score:.3f}"
        return result, comparison_path, dice_score, iou_score, status
    else:
        return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed."


def process_sequence(image_files, prompt_text, modality, window_type):
    """Process multiple images from the same subject and return gallery of results.

    Returns:
        Tuple of (list of result image paths, status message).
    """
    if not is_model_loaded():
        return [], "❌ Error: Model not loaded."
    if not image_files:
        return [], "⚠️ Please upload medical image files (DICOM, PNG, or JPG)."
    # Handle single file or list of files
    if isinstance(image_files, str):
        image_files = [image_files]
    results = []
    status_messages = []
    for idx, image_file in enumerate(image_files):
        if image_file is None:
            continue
        status_msg = f"Processing image {idx + 1}/{len(image_files)}..."
        status_messages.append(status_msg)
        result = process_medical_image(image_file, prompt_text, modality, window_type)
        if result:
            results.append(result)
            status_messages.append(f"✅ Image {idx + 1} segmented successfully")
        else:
            status_messages.append(f"❌ Failed to process image {idx + 1}")
    if results:
        status = f"✅ Processed {len(results)}/{len(image_files)} images successfully!\n" + "\n".join(status_messages)
        return results, status
    else:
        return [], "❌ No images were processed successfully. Check console for error details."


# Store processed results for interactive viewer (now using cache_manager)
# processed_results_cache is imported from cache_manager
# extract_subject_id and group_images_by_subject are now imported from utils module
def detect_subjects(image_files):
    """Detect and return subject groups from uploaded files.

    Returns:
        Tuple of (gr.Dropdown of subject choices, status message).
    """
    if not image_files:
        return gr.Dropdown(choices=[], value=None), "No files uploaded"
    subject_groups = group_images_by_subject(image_files)
    if not subject_groups:
        return gr.Dropdown(choices=[], value=None), "No subjects detected"
    choices = []
    status_msg = f"✅ Detected {len(subject_groups)} subject(s):\n\n"
    for subject_id, info in sorted(subject_groups.items()):
        num_files = len(info['files'])
        confidence = info['confidence']
        sources = ', '.join(info['sources'])
        # Add confidence indicator
        if confidence == 'high':
            confidence_icon = "✅"
            confidence_text = "HIGH (DICOM metadata - very reliable)"
        elif confidence == 'medium':
            confidence_icon = "⚠️"
            confidence_text = "MEDIUM (folder/filename pattern - likely same patient)"
        else:
            confidence_icon = "⚠️⚠️"
            confidence_text = "LOW (filename-based - verify manually)"
        choices.append(f"{subject_id} ({num_files} slices)")
        status_msg += f"{confidence_icon} **{subject_id}**: {num_files} slices\n"
        status_msg += f" Confidence: {confidence_text}\n"
        status_msg += f" Source: {sources}\n\n"
    # Add warning for low confidence
    low_confidence_count = sum(1 for info in subject_groups.values() if info['confidence'] == 'low')
    if
low_confidence_count > 0:
        status_msg += f"⚠️ **Warning**: {low_confidence_count} subject(s) detected with LOW confidence.\n"
        status_msg += "Please verify these are actually the same patient before proceeding.\n"
    return gr.Dropdown(choices=choices, value=choices[0] if choices else None), status_msg


def process_slices_for_viewer(image_files, selected_subject, prompt_text, modality, window_type):
    """Process all slices for selected subject and cache results for interactive viewing.

    Returns:
        Tuple of (first slice image, max slider index, status message,
        slice info text, updated subject gr.Dropdown, viewing label).
    """
    if not is_model_loaded():
        return None, 0, "❌ Error: Model not loaded.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
    if not image_files:
        return None, 0, "⚠️ Please upload medical image files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
    # Group by subject
    subject_groups = group_images_by_subject(image_files)
    if not subject_groups:
        return None, 0, "⚠️ Could not detect subjects in uploaded files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
    # Extract subject ID from selection (format: "subject_id (N slices)")
    if selected_subject:
        subject_id = selected_subject.split(" (")[0]
    else:
        # Use first subject if none selected
        subject_id = list(subject_groups.keys())[0]
    if subject_id not in subject_groups:
        return None, 0, f"⚠️ Subject '{subject_id}' not found.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
    # Get files for selected subject
    subject_info = subject_groups[subject_id]
    subject_files = subject_info['files']
    confidence = subject_info['confidence']
    # Add confidence warning
    confidence_warning = ""
    if confidence == 'low':
        confidence_warning = "\n⚠️ LOW CONFIDENCE: These files may not be from the same patient. Please verify!"
    elif confidence == 'medium':
        confidence_warning = "\n⚠️ MEDIUM CONFIDENCE: Likely same patient, but verify if critical."
    results = []
    status_messages = []
    for idx, image_file in enumerate(subject_files):
        status_msg = f"Processing slice {idx + 1}/{len(subject_files)}..."
        status_messages.append(status_msg)
        result = process_medical_image(image_file, prompt_text, modality, window_type)
        if result:
            results.append(result)
            status_messages.append(f"✅ Slice {idx + 1} processed")
        else:
            status_messages.append(f"❌ Failed to process slice {idx + 1}")
    if results:
        # Cache results with a unique key including subject ID
        # (same key shape is rebuilt in navigate_slice for lookups).
        cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"
        processed_results_cache[cache_key] = results
        max_slices = len(results) - 1
        status = f"✅ Processed {len(results)}/{len(subject_files)} slices for {subject_id}!\nUse slider or buttons to navigate.{confidence_warning}"
        slice_info = f"Slice 1/{len(results)} ({subject_id})"
        # Update subject dropdown choices
        choices = []
        for sid, info in sorted(subject_groups.items()):
            marker = "✓" if sid == subject_id else ""
            num_files = len(info['files'])
            choices.append(f"{marker} {sid} ({num_files} slices)")
        return results[0], max_slices, status, slice_info, gr.Dropdown(choices=choices, value=choices[0] if choices else None), f"Viewing: {subject_id}"
    else:
        return None, 0, "❌ No slices were processed successfully.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""


def navigate_slice(slice_idx, image_files, selected_subject, prompt_text, modality, window_type):
    """Navigate to a specific slice in the sequence.

    Serves from processed_results_cache when available; otherwise processes
    the requested slice on demand (continues below this chunk).
    """
    if not image_files:
        return None, "No slices loaded"
    # Group by subject and get selected subject's files
    subject_groups = group_images_by_subject(image_files)
    if selected_subject:
        subject_id = selected_subject.split(" (")[0]
    else:
        subject_id = list(subject_groups.keys())[0] if subject_groups else None
    if not subject_id or subject_id not in subject_groups:
        return None, "No slices loaded"
    subject_info = subject_groups[subject_id]
    subject_files = subject_info['files']
    slice_idx = int(slice_idx)
    cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"
    if cache_key in processed_results_cache:
        results = processed_results_cache[cache_key]
        if 0 <=
slice_idx < len(results): slice_info = f"Slice {slice_idx + 1}/{len(results)} ({subject_id})" return results[slice_idx], slice_info # If not cached, process on the fly (fallback) if 0 <= slice_idx < len(subject_files): result = process_medical_image(subject_files[slice_idx], prompt_text, modality, window_type) if result: slice_info = f"Slice {slice_idx + 1}/{len(subject_files)} ({subject_id})" return result, slice_info return None, f"Invalid slice index: {slice_idx}" with gr.Blocks() as demo: gr.Markdown("# 🏥 NeuroSAM 3: Medical Image Segmentation") demo_info = "" if demo_file_path: demo_info = f"\n\n**📁 Demo File Available:** A sample DICOM file is ready: `{demo_file_path}`\nClick 'Load Demo File' button below to use it!" gr.Markdown(f""" Upload a medical image (DICOM .dcm, PNG, or JPG) and type what you want to find (e.g., 'brain', 'tumor', 'skull'). {demo_info} **Instructions:** 1. Upload a medical image file: - DICOM (.dcm) files for CT or MRI scans - PNG/JPG files (e.g., from Kaggle datasets like brain MRI images) 2. Enter a text prompt describing what to segment 3. Select the imaging modality (CT or MRI) 4. Choose the windowing strategy (for CT images) 5. Click "Segment Structure" to process **Supported Formats:** - DICOM (.dcm) - Standard medical imaging format - PNG/JPG - Standard image formats (works with Kaggle brain MRI datasets) """) with gr.Tabs(): with gr.Tab("Single Image"): with gr.Row(): with gr.Column(): file_input = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath", value=demo_file_path ) load_demo_btn = gr.Button( "📁 Load Demo File", variant="secondary", size="sm", visible=bool(demo_file_path) ) text_input = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) submit_btn = gr.Button("Segment Structure", variant="primary", size="lg") with gr.Column(): image_output = gr.Image( label="Segmentation Result", type="filepath" ) gr.Markdown("### Status") status_text = gr.Textbox( label="Processing Status", value="Ready. Upload a medical image file (DICOM, PNG, or JPG) to begin.", interactive=False ) with gr.Tab("Interactive Slice Viewer"): gr.Markdown("**Scroll through multiple slices/images from the same subject interactively**") gr.Markdown(""" **📋 Subject Detection:** The app automatically detects subject/patient IDs from: - Folder names (e.g., `subject_001/`, `patient_123/`) - Filenames (e.g., `subject_001_slice_01.png`, `patient_123.dcm`) - DICOM metadata (PatientID, StudyInstanceUID) **💡 Tip:** Upload images organized by subject folders for best results! """) with gr.Row(): with gr.Column(): files_input = gr.File( label="Upload Multiple Images/Slices (Select multiple files)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], file_count="multiple", type="filepath" ) subject_dropdown = gr.Dropdown( label="Select Subject/Patient", choices=[], value=None, interactive=True, info="Select which subject's slices to view (auto-detected from filenames/folders)" ) text_input_batch = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown_batch = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown_batch = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) detect_subjects_btn = gr.Button("🔍 Detect Subjects", variant="secondary", size="sm") submit_batch_btn = gr.Button("Process All Slices", variant="primary", size="lg") gr.Markdown("---") gr.Markdown("### 🎛️ Slice Navigator") slice_slider = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Slice Number", info="Use slider or arrow keys to navigate through slices", interactive=False ) with gr.Row(): prev_btn = gr.Button("⬆️ Previous Slice", size="sm") next_btn = gr.Button("⬇️ Next Slice", size="sm") auto_play_btn = gr.Button("▶️ Auto-play", size="sm") with gr.Column(): current_slice_output = gr.Image( label="Current Slice Segmentation", type="filepath", height=600 ) gr.Markdown("### Slice Info") slice_info_text = gr.Textbox( label="Current Slice", value="No slices loaded", interactive=False ) subject_info_text = gr.Textbox( label="Subject Info", value="", interactive=False, visible=False ) gr.Markdown("### Status") status_batch_text = gr.Textbox( label="Processing Status", value="Ready. Upload multiple medical image files to process a sequence.", interactive=False, lines=4 ) with gr.Tab("Gallery View"): gr.Markdown("**View all segmentations in a gallery grid**") with gr.Row(): with gr.Column(): files_input_gallery = gr.File( label="Upload Multiple Images (Select multiple files)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], file_count="multiple", type="filepath" ) text_input_gallery = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull, eyes" ) with gr.Row(): modality_dropdown_gallery = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI" ) window_dropdown_gallery = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)" ) submit_gallery_btn = gr.Button("Process & Show Gallery", variant="primary", size="lg") with gr.Column(): gallery_output = gr.Gallery( label="Segmentation Gallery", show_label=True, elem_id="gallery", columns=2, rows=2, height="auto" ) status_gallery_text = gr.Textbox( label="Status", value="Ready. Upload multiple images to view in gallery.", interactive=False ) with gr.Tab("Compare with Ground Truth"): gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**") with gr.Row(): with gr.Column(): file_input_gt = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath" ) gt_mask_input = gr.File( label="Upload Ground Truth Mask (PNG, JPG)", file_types=[".png", ".jpg", ".jpeg"], type="filepath" ) text_input_gt = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_dropdown_gt = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_dropdown_gt = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg") with gr.Column(): image_output_gt = gr.Image( label="SAM 3 Segmentation", type="filepath" ) comparison_output = gr.Image( label="Comparison: SAM 3 vs Ground Truth", type="filepath" ) gr.Markdown("### Metrics") dice_score_text = gr.Textbox( label="Dice Score", value="--", interactive=False ) iou_score_text = gr.Textbox( label="IoU Score", value="--", interactive=False ) gr.Markdown("### Status") status_gt_text = gr.Textbox( label="Processing Status", value="Ready. Upload image and ground truth mask to compare.", interactive=False, lines=3 ) # NEW ENHANCED TABS with gr.Tab("✨ Enhanced Single Image"): gr.Markdown("**Enhanced version with image adjustments, visualization options, and progress tracking**") with gr.Row(): with gr.Column(): file_input_enh = gr.File( label="Upload Medical Image (DICOM .dcm, PNG, JPG)", file_types=[".dcm", ".png", ".jpg", ".jpeg"], type="filepath", value=demo_file_path ) text_input_enh = gr.Textbox( label="Text Prompt", value="brain", placeholder="e.g. 
brain, tumor, skull, eyes", info="Describe what anatomical structure or region you want to segment" ) with gr.Row(): modality_enh = gr.Dropdown( ["CT", "MRI"], label="Modality", value="MRI", info="Select the imaging modality" ) window_enh = gr.Dropdown( ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], label="Windowing Strategy (CT only)", value="Brain (Grey Matter)", info="CT windowing preset (ignored for MRI)" ) # NEW: Image adjustment controls gr.Markdown("### 🎨 Image Adjustments") brightness_slider = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Brightness", info="Adjust image brightness (0.5 = darker, 2.0 = brighter)" ) contrast_slider = gr.Slider( 0.5, 2.0, value=1.0, step=0.1, label="Contrast", info="Adjust image contrast (0.5 = lower contrast, 2.0 = higher contrast)" ) # NEW: Visualization options gr.Markdown("### 🎭 Visualization Options") colormap_dropdown = gr.Dropdown( ["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"], label="Mask Colormap", value="spring", info="Color scheme for segmentation overlay" ) transparency_slider = gr.Slider( 0.0, 1.0, value=0.5, step=0.1, label="Mask Transparency", info="Transparency of segmentation overlay (0.0 = opaque, 1.0 = transparent)" ) submit_enh_btn = gr.Button("Segment with Progress", variant="primary", size="lg") with gr.Column(): # Progress indicator progress_text = gr.Textbox( label="Progress", value="Ready. 
Upload a medical image file to begin.", interactive=False)
# NOTE(review): everything below belongs inside the `with gr.Blocks() as demo:`
# context opened earlier in the file; re-indent one level when re-integrating.
# The statement above is the tail of the "Enhanced" tab's status gr.Textbox,
# whose opening lines are in the preceding chunk.
image_output_enh = gr.Image(label="Segmentation Result", type="filepath")
# NEW: Download button
download_output = gr.File(label="Download Result", visible=True)
# NEW: Metrics display
gr.Markdown("### 📊 Metrics")
metrics_text = gr.Textbox(label="Segmentation Info", value="", lines=3, interactive=False)

# --- Tab: batch processing with ZIP packaging of all results ---------------
with gr.Tab("✨ Enhanced Batch Processing"):
    gr.Markdown("**Enhanced batch processing with ZIP download and progress tracking**")
    with gr.Row():
        with gr.Column():
            files_input_enh_batch = gr.File(
                label="Upload Multiple Images for Batch Processing",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                file_count="multiple",
                type="filepath"
            )
            text_input_enh_batch = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull, eyes"
            )
            with gr.Row():
                modality_enh_batch = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_enh_batch = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing Strategy (CT only)", value="Brain (Grey Matter)"
                )
            # Image adjustments
            gr.Markdown("### 🎨 Image Adjustments")
            brightness_slider_batch = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Brightness")
            contrast_slider_batch = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Contrast")
            # Visualization options
            gr.Markdown("### 🎭 Visualization Options")
            colormap_dropdown_batch = gr.Dropdown(
                ["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"],
                label="Mask Colormap", value="spring"
            )
            transparency_slider_batch = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Mask Transparency")
            submit_enh_batch_btn = gr.Button("Process Batch with Progress", variant="primary", size="lg")
        with gr.Column():
            gallery_output_enh = gr.Gallery(
                label="Segmentation Gallery", show_label=True,
                elem_id="gallery", columns=2, rows=2, height="auto"
            )
            # Batch download ZIP
            batch_download_output = gr.File(label="Download All Results (ZIP)", visible=True)
            status_enh_batch_text = gr.Textbox(
                label="Status",
                value="Ready. Upload multiple images to process in batch.",
                interactive=False, lines=4
            )

# NEW: Point/Box Prompts Tab -- spatial (non-text) prompting of SAM
with gr.Tab("🎯 Point/Box Prompts"):
    # NOTE(review): triple-quoted Markdown strings below were flattened by the
    # whitespace mangling; line breaks are reconstructed from the Markdown
    # structure (**bold** headers, "- " bullets) -- confirm against original.
    gr.Markdown("""
    **Interactive Point and Box-based Segmentation**

    Use precise point clicks or bounding boxes to guide the segmentation.

    - **Point Prompt**: Click on the region you want to segment
    - **Box Prompt**: Define a bounding box around the region of interest
    """)
    with gr.Tabs():
        with gr.Tab("Point Prompt"):
            with gr.Row():
                with gr.Column():
                    file_input_point = gr.File(
                        label="Upload Medical Image",
                        file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                        type="filepath"
                    )
                    gr.Markdown("### Point Coordinates")
                    with gr.Row():
                        # Defaults (128, 128) target the center of a 256x256 slice
                        point_x = gr.Number(label="X coordinate", value=128, precision=0)
                        point_y = gr.Number(label="Y coordinate", value=128, precision=0)
                    with gr.Row():
                        modality_point = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                        window_point = gr.Dropdown(
                            ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                            label="Windowing", value="Brain (Grey Matter)"
                        )
                    with gr.Row():
                        colormap_point = gr.Dropdown(
                            ["spring", "cool", "hot", "viridis", "plasma"],
                            label="Colormap", value="spring"
                        )
                        transparency_point = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
                    submit_point_btn = gr.Button("Segment at Point", variant="primary")
                with gr.Column():
                    output_point = gr.Image(label="Point Segmentation Result", type="filepath")
                    status_point = gr.Textbox(label="Status", interactive=False)
        with gr.Tab("Box Prompt"):
            with gr.Row():
                with gr.Column():
                    file_input_box = gr.File(
                        label="Upload Medical Image",
                        file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                        type="filepath"
                    )
                    gr.Markdown("### Bounding Box Coordinates")
                    with gr.Row():
                        # Box is specified as top-left (x1, y1) ...
                        box_x1 = gr.Number(label="X1 (left)", value=50, precision=0)
                        box_y1 = gr.Number(label="Y1 (top)", value=50, precision=0)
                    with gr.Row():
                        # ... and bottom-right (x2, y2) corners
                        box_x2 = gr.Number(label="X2 (right)", value=200, precision=0)
                        box_y2 = gr.Number(label="Y2 (bottom)", value=200, precision=0)
                    with gr.Row():
                        modality_box = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                        window_box = gr.Dropdown(
                            ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                            label="Windowing", value="Brain (Grey Matter)"
                        )
                    with gr.Row():
                        colormap_box = gr.Dropdown(
                            ["spring", "cool", "hot", "viridis", "plasma"],
                            label="Colormap", value="spring"
                        )
                        transparency_box = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")
                    submit_box_btn = gr.Button("Segment in Box", variant="primary")
                with gr.Column():
                    output_box = gr.Image(label="Box Segmentation Result", type="filepath")
                    status_box = gr.Textbox(label="Status", interactive=False)

# NEW: ROI Statistics & Export Tab -- per-mask statistics, NIFTI export,
# and save/load of annotations as .zip bundles
with gr.Tab("📊 ROI Statistics & Export"):
    gr.Markdown("""
    **ROI Statistics and Export Options**

    Process an image and get detailed statistics about the segmented region:

    - Area (pixels and percentage)
    - Intensity statistics (mean, std, min, max)
    - Centroid and bounding box
    - Export to NIFTI format for medical imaging software
    - Save/Load annotations for later use
    """)
    with gr.Row():
        with gr.Column():
            file_input_stats = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            text_input_stats = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )
            with gr.Row():
                modality_stats = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_stats = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )
            submit_stats_btn = gr.Button("Process & Calculate Statistics", variant="primary")
            gr.Markdown("### Export Options")
            with gr.Row():
                # Export buttons operate on the most recently processed mask
                # (wired to export_last_mask_nifti / save_last_annotation later)
                export_nifti_btn = gr.Button("📥 Export to NIFTI", size="sm")
                save_annotation_btn = gr.Button("💾 Save Annotation", size="sm")
        with gr.Column():
            output_stats = gr.Image(label="Segmentation Result", type="filepath")
            status_stats = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### 📊 ROI Statistics")
            roi_stats_text = gr.Markdown(value="*Process an image to see statistics*")
            nifti_download = gr.File(label="Download NIFTI", visible=True)
            annotation_download = gr.File(label="Download Annotation", visible=True)
    gr.Markdown("---")
    gr.Markdown("### Load Saved Annotation")
    with gr.Row():
        with gr.Column():
            annotation_upload = gr.File(
                label="Upload Annotation (.zip)", file_types=[".zip"], type="filepath"
            )
            original_image_upload = gr.File(
                label="Upload Original Image (for visualization)",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            load_annotation_btn = gr.Button("Load & Visualize Annotation", variant="secondary")
        with gr.Column():
            loaded_annotation_output = gr.Image(label="Loaded Annotation", type="filepath")
            loaded_annotation_info = gr.Markdown(value="*Upload an annotation file to load*")

# NEW: Multi-Mask Output Tab -- several candidate masks per prompt
with gr.Tab("🎭 Multi-Mask Output"):
    gr.Markdown("""
    **Generate Multiple Mask Candidates**

    SAM can generate multiple segmentation hypotheses with confidence scores.
    This is useful when the segmentation is ambiguous or you want to compare alternatives.
    """)
    with gr.Row():
        with gr.Column():
            file_input_multi = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            text_input_multi = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )
            with gr.Row():
                modality_multi = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_multi = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )
            num_masks_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Masks")
            submit_multi_btn = gr.Button("Generate Multiple Masks", variant="primary")
        with gr.Column():
            gallery_multi = gr.Gallery(
                label="Mask Candidates", show_label=True, columns=2, rows=2, height="auto"
            )
            status_multi = gr.Textbox(label="Status", interactive=False)
            mask_info_multi = gr.Textbox(label="Mask Information", lines=5, interactive=False)

# NEW: SAM-Medical-Imaging Inspired Tabs
with gr.Tab("🔬 Automatic Mask Generator"):
    gr.Markdown("""
    **Automatic Mask Generator (AMG)**

    Inspired by [SAM-Medical-Imaging](https://github.com/amine0110/SAM-Medical-Imaging) - automatically segment the entire image without needing specific prompts.

    - Uses multiple prompts internally to discover all regions
    - Filters and deduplicates overlapping masks
    - Great for exploratory analysis of medical images
    """)
    with gr.Row():
        with gr.Column():
            file_input_amg = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            with gr.Row():
                modality_amg = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_amg = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )
            gr.Markdown("### AMG Parameters")
            # Grid density controls how many internal point prompts are tried
            points_per_side = gr.Slider(
                8, 32, value=16, step=4, label="Grid Density",
                info="Higher = more detailed but slower"
            )
            min_mask_area = gr.Slider(
                50, 1000, value=100, step=50, label="Minimum Mask Area (pixels)",
                info="Filter out small noise regions"
            )
            # Qualitative (categorical) colormaps so distinct regions get distinct colors
            colormap_amg = gr.Dropdown(
                ["tab20", "tab10", "Set1", "Set2", "Paired", "rainbow"],
                label="Colormap for Regions", value="tab20"
            )
            submit_amg_btn = gr.Button("🔬 Run Automatic Segmentation", variant="primary", size="lg")
        with gr.Column():
            output_amg = gr.Image(label="AMG Result", type="filepath")
            status_amg = gr.Textbox(label="Status", interactive=False)
            info_amg = gr.Markdown(value="*Run AMG to see region details*")

with gr.Tab("🔧 Advanced Transforms"):
    gr.Markdown("""
    **Advanced Image Transforms**

    Apply preprocessing transforms from SAM-Medical-Imaging before segmentation:

    - **ResizeLongestSide**: Resize maintaining aspect ratio (optimal for SAM)
    - **CLAHE-like Enhancement**: Enhance local contrast for better visibility
    """)
    with gr.Row():
        with gr.Column():
            file_input_transform = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            text_input_transform = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )
            with gr.Row():
                modality_transform = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_transform = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )
            gr.Markdown("### Transform Settings")
            # 1024 is SAM's native input resolution; step of 128 keeps sizes model-friendly
            target_size_slider = gr.Slider(
                256, 2048, value=1024, step=128,
                label="Target Size (ResizeLongestSide)",
                info="Resize image's longest side to this value"
            )
            apply_clahe_checkbox = gr.Checkbox(
                label="Apply CLAHE-like Enhancement", value=False,
                info="Enhance local contrast (useful for low-contrast images)"
            )
            clahe_clip_slider = gr.Slider(
                1.0, 4.0, value=2.0, step=0.5, label="CLAHE Clip Limit",
                info="Higher = more contrast enhancement"
            )
            with gr.Row():
                colormap_transform = gr.Dropdown(
                    ["spring", "cool", "hot", "viridis", "plasma"],
                    label="Colormap", value="spring"
                )
                transparency_transform = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.1, label="Transparency"
                )
            submit_transform_btn = gr.Button("🔧 Process with Transforms", variant="primary")
        with gr.Column():
            output_transform = gr.Image(label="Transform + Segmentation Result", type="filepath")
            status_transform = gr.Textbox(label="Status", interactive=False)

with gr.Tab("🌊 Edge-Based Segmentation"):
    gr.Markdown("""
    **Edge-Based Automatic Segmentation**

    Uses Sobel edge detection to find boundaries in medical images.
    Works without AI model - useful for comparison or when model fails.

    - Detects edges using gradient-based methods
    - Fills regions within boundaries
    - Identifies distinct anatomical structures
    """)
    with gr.Row():
        with gr.Column():
            file_input_edge = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )
            with gr.Row():
                modality_edge = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_edge = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )
            gr.Markdown("### Edge Detection Parameters")
            edge_threshold_slider = gr.Slider(
                10, 150, value=50, step=10, label="Edge Threshold",
                info="Lower = more edges detected"
            )
            dilation_size_slider = gr.Slider(
                0, 10, value=3, step=1, label="Dilation Size",
                info="Connect nearby edges (0 = no dilation)"
            )
            with gr.Row():
                colormap_edge = gr.Dropdown(
                    ["spring", "cool", "hot", "viridis", "plasma"],
                    label="Colormap", value="spring"
                )
                transparency_edge = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.1, label="Transparency"
                )
            submit_edge_btn = gr.Button("🌊 Run Edge Segmentation", variant="primary")
        with gr.Column():
            output_edge = gr.Image(label="Edge Segmentation Result", type="filepath")
            status_edge = gr.Textbox(label="Status", interactive=False)

# ---------------------------------------------------------------------------
# Event wiring: connect buttons/sliders declared above to their handlers
# ---------------------------------------------------------------------------

# Single image processing
load_demo_btn.click(
    fn=load_demo_file,
    inputs=[],
    outputs=[file_input, status_text]
)
submit_btn.click(
    fn=process_with_status,
    inputs=[file_input, text_input, modality_dropdown, window_dropdown],
    outputs=[image_output, status_text]
)

# Detect subjects when files are uploaded
detect_subjects_btn.click(
    fn=detect_subjects,
    inputs=[files_input],
    outputs=[subject_dropdown, status_batch_text]
)

# Interactive slice viewer: process all slices, then re-enable the slider
# with a maximum derived from the processing result (lambda continues on the
# next source line).
submit_batch_btn.click(
    fn=process_slices_for_viewer,
    inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_slider, status_batch_text, slice_info_text, subject_dropdown, subject_info_text]
).then(
    lambda max_val:
gr.Slider(maximum=max(max_val, 1), interactive=True),
    # NOTE(review): `inputs=[slice_slider]` feeds the slider's current *value*
    # into `max_val`, not the number of processed slices -- confirm this is the
    # intended way to size the slider after batch processing.
    inputs=[slice_slider],
    outputs=[slice_slider]
)


def update_slice(slice_num, files, selected_subject, prompt, mod, window):
    """Render the slice at `slice_num` for the currently selected subject.

    Thin adapter around `navigate_slice`; coerces the slider value (which
    Gradio may deliver as a float) to int before delegating.

    Returns:
        Tuple of (result image, info text) passed through from navigate_slice.
    """
    result, info = navigate_slice(int(slice_num), files, selected_subject, prompt, mod, window)
    return result, info


slice_slider.change(
    fn=update_slice,
    inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_info_text]
)


def prev_slice(current, files, selected_subject, prompt, mod, window):
    """Step the slice viewer one slice backward, clamped at slice 0.

    Returns (new slider value, result image, info text) so the slider and the
    viewer stay in sync.
    """
    # NOTE(review): unlike update_slice, `new_val` is not int()-coerced here;
    # `current` may be a float from the slider -- confirm navigate_slice copes.
    new_val = max(0, current - 1)
    result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
    return new_val, result, info


def next_slice(current, max_val, files, selected_subject, prompt, mod, window):
    """Step the slice viewer one slice forward, clamped at `max_val`.

    `max_val` is wired from the slider itself (passed twice in inputs below),
    so the clamp uses the slider's configured position/value.
    """
    new_val = min(max_val, current + 1)
    result, info = navigate_slice(new_val, files, selected_subject, prompt, mod, window)
    return new_val, result, info


prev_btn.click(
    fn=prev_slice,
    inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[slice_slider, current_slice_output, slice_info_text]
)
next_btn.click(
    # slice_slider appears twice: once as `current`, once as `max_val`
    fn=next_slice,
    inputs=[slice_slider, slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[slice_slider, current_slice_output, slice_info_text]
)

# Gallery view
submit_gallery_btn.click(
    fn=process_sequence,
    inputs=[files_input_gallery, text_input_gallery, modality_dropdown_gallery, window_dropdown_gallery],
    outputs=[gallery_output, status_gallery_text]
)

# Ground truth comparison
submit_gt_btn.click(
    fn=process_with_ground_truth,
    inputs=[file_input_gt, gt_mask_input, text_input_gt, modality_dropdown_gt, window_dropdown_gt],
    outputs=[image_output_gt, comparison_output, dice_score_text, iou_score_text, status_gt_text]
)


# Enhanced single image processing
def process_enhanced_wrapper(image_file, prompt_text, modality, window_type, brightness,
                             contrast, colormap, transparency, progress=gr.Progress()):
    """Wrapper to return both image and download file.

    Delegates the actual work to `process_with_progress` and duplicates the
    resulting image path into a fourth return value so the same file feeds
    both the image preview and the download component. When processing fails
    (`result` is falsy) the download slot is cleared with None.
    """
    result, status, metrics = process_with_progress(
        image_file, prompt_text, modality, window_type, brightness,
        contrast, colormap, transparency, progress
    )
    download_file = result if result else None
    return result, status, metrics, download_file


submit_enh_btn.click(
    fn=process_enhanced_wrapper,
    inputs=[
        file_input_enh, text_input_enh, modality_enh, window_enh,
        brightness_slider, contrast_slider, colormap_dropdown, transparency_slider
    ],
    outputs=[image_output_enh, progress_text, metrics_text, download_output]
)

# Enhanced batch processing
submit_enh_batch_btn.click(
    fn=process_batch_enhanced,
    inputs=[
        files_input_enh_batch, text_input_enh_batch, modality_enh_batch, window_enh_batch,
        brightness_slider_batch, contrast_slider_batch, colormap_dropdown_batch, transparency_slider_batch
    ],
    outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
)

# Point prompt processing
submit_point_btn.click(
    fn=process_with_point_prompt,
    inputs=[file_input_point, point_x, point_y, modality_point, window_point, colormap_point, transparency_point],
    outputs=[output_point, status_point]
)

# Box prompt processing
submit_box_btn.click(
    fn=process_with_box_prompt,
    inputs=[file_input_box, box_x1, box_y1, box_x2, box_y2, modality_box, window_box, colormap_box, transparency_box],
    outputs=[output_box, status_box]
)

# ROI Statistics processing
submit_stats_btn.click(
    fn=process_and_store_mask,
    inputs=[file_input_stats, text_input_stats, modality_stats, window_stats],
    outputs=[output_stats, status_stats, roi_stats_text]
)

# NIFTI Export -- no inputs: exports the mask stored by process_and_store_mask
export_nifti_btn.click(
    fn=export_last_mask_nifti,
    inputs=[],
    outputs=[nifti_download, status_stats]
)

# Save Annotation -- likewise operates on the last stored mask
save_annotation_btn.click(
    fn=save_last_annotation,
    inputs=[],
    outputs=[annotation_download, status_stats]
)

# Load Annotation (inputs list continues on the next source line)
load_annotation_btn.click(
    fn=visualize_loaded_annotation,
    inputs=[original_image_upload,
annotation_upload],
    outputs=[loaded_annotation_output, loaded_annotation_info]
)

# Multi-Mask processing
submit_multi_btn.click(
    fn=process_multi_mask,
    inputs=[file_input_multi, text_input_multi, modality_multi, window_multi, num_masks_slider],
    outputs=[gallery_multi, status_multi, mask_info_multi]
)


# Auto-play functionality for slice viewer
def auto_play_slices(files, selected_subject, prompt, mod, window):
    """Auto-play through slices with a short delay.

    Generator callback: yields (image, info text, slice index) tuples so
    Gradio streams each slice to the viewer in turn. Requires that the slices
    were already processed via the batch button -- it only replays entries
    found in `processed_results_cache`; otherwise it yields a single
    "Please process slices first" frame and stops.

    Args:
        files: Uploaded file list from the batch uploader.
        selected_subject: Dropdown choice formatted as "<subject_id> (<n> files)";
            only the id before " (" is used. Falls back to the first detected
            subject when empty.
        prompt/mod/window: Must match the values used when processing, since
            they are part of the cache key.
    """
    if not files:
        yield None, "No slices loaded", 0
        return
    subject_groups = group_images_by_subject(files)
    if selected_subject:
        # Dropdown labels embed a suffix after " (" -- strip it to recover the id
        subject_id = selected_subject.split(" (")[0]
    else:
        subject_id = list(subject_groups.keys())[0] if subject_groups else None
    if not subject_id or subject_id not in subject_groups:
        yield None, "No slices loaded", 0
        return
    subject_files = subject_groups[subject_id]['files']
    # Cache key must mirror the one built by the batch-processing handler
    # (window is deliberately not part of the key here -- presumably matches
    # the producer side; TODO confirm).
    cache_key = f"{subject_id}_{len(subject_files)}_{prompt}_{mod}"
    if cache_key not in processed_results_cache:
        yield None, "Please process slices first", 0
        return
    results = processed_results_cache[cache_key]
    for idx in range(len(results)):
        slice_info = f"Slice {idx + 1}/{len(results)} ({subject_id}) - Auto-playing..."
        yield results[idx], slice_info, idx
        time.sleep(0.5)  # 500ms delay between slices


auto_play_btn.click(
    fn=auto_play_slices,
    inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_info_text, slice_slider]
)

# SAM-Medical-Imaging Inspired Features

# Automatic Mask Generator
submit_amg_btn.click(
    fn=automatic_mask_generator,
    inputs=[file_input_amg, modality_amg, window_amg, points_per_side, min_mask_area, colormap_amg],
    outputs=[output_amg, status_amg, info_amg]
)

# Advanced Transforms
submit_transform_btn.click(
    fn=process_with_advanced_transforms,
    inputs=[file_input_transform, text_input_transform, modality_transform, window_transform,
            target_size_slider, apply_clahe_checkbox, clahe_clip_slider,
            colormap_transform, transparency_transform],
    outputs=[output_transform, status_transform]
)

# Edge-Based Segmentation
submit_edge_btn.click(
    fn=edge_based_segmentation,
    inputs=[file_input_edge, modality_edge, window_edge, edge_threshold_slider,
            dilation_size_slider, colormap_edge, transparency_edge],
    outputs=[output_edge, status_edge]
)


if __name__ == "__main__":
    # Verify model is loaded before launching; the app still starts without
    # the model so the UI remains inspectable, but segmentation is disabled.
    if not is_model_loaded():
        logger.warning("SAM 3 model failed to load!")
        logger.warning("The app will start but segmentation features will not work.")
        logger.warning("Please check:")
        logger.warning(" 1. HF_TOKEN environment variable is set correctly")
        logger.warning(" 2. transformers>=4.45.0 is installed")
        logger.warning(" 3. Sufficient memory/GPU available")
    else:
        logger.info("SAM 3 model ready - app starting...")
    # 0.0.0.0 binds all interfaces (container deployment); mcp_server exposes
    # the app's handlers over the MCP protocol as well.
    demo.launch(server_name="0.0.0.0", mcp_server=True, server_port=7860)