Spaces:
Runtime error
Runtime error
| """ | |
| NeuroSeg Server - HydroMorph Backend API | |
| ========================================= | |
| Backend API for HydroMorph React Native app (iOS, Android, Web). | |
| ENDPOINTS FOR MOBILE APP: | |
| - POST /gradio_api/upload - Upload PNG slice | |
| - POST /gradio_api/call/{endpoint} - Call segmentation endpoint | |
| - GET /gradio_api/call/{endpoint}/{event_id} - SSE poll for result | |
| - GET /gradio_api/file={path} - Download result image | |
| - POST /api/segment_2d - Direct JSON API (no Gradio protocol) | |
| - POST /api/segment_3d - Direct JSON API for 3D volumes | |
| - GET /api/health - Health check | |
| MCP SERVER: | |
| - All models exposed as MCP tools at /gradio_api/mcp/sse | |
| MODELS SUPPORTED: | |
| - MedSAM2: 3D volume with bi-directional propagation | |
| - MCP-MedSAM: Fast 2D with modality/content prompts | |
| - SAM-Med3D: Native 3D (245+ classes, sliding window) | |
| - MedSAM-3D: 3D with memory bank | |
| - TractSeg: White matter bundles (72 tracts) | |
| - nnU-Net: Self-configuring U-Net | |
| Author: Matheus Machado Rech | |
| """ | |
| import gzip | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import tempfile | |
| import base64 | |
| import time | |
| import urllib.request | |
| from typing import Optional, Tuple, List, Dict, Any | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from functools import wraps | |
| import gradio as gr | |
| import spaces | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image, ImageDraw | |
| from huggingface_hub import hf_hub_download | |
| import nibabel as nib | |
| import scipy | |
# Logging: one process-wide configuration plus the logger used by this module.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("neuroseg_server")

# Working directories live next to this script and are created on import.
SCRIPT_DIR = Path(__file__).parent.resolve()
CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
TEMP_DIR = SCRIPT_DIR / "temp"
SAMPLES_DIR = SCRIPT_DIR / "samples"
for _work_dir in (CHECKPOINT_DIR, TEMP_DIR, SAMPLES_DIR):
    _work_dir.mkdir(exist_ok=True)
| # ============================================================================= | |
| # SAMPLE DATA CONFIGURATION | |
| # ============================================================================= | |
# Every bundled sample is served from the same dataset folder; only the file
# name differs, so the URL is derived from it.
_NPH_DATASET_BASE = (
    "https://huggingface.co/datasets/radimagenet/"
    "normal-pressure-hydrocephalus/resolve/main/"
)


def _nph_sample(filename, name, description, box):
    """Build one sample record.

    ``box`` is an ``(x1, y1, x2, y2)`` tuple that becomes the sample's
    ``default_box`` dictionary.
    """
    x1, y1, x2, y2 = box
    return {
        "url": _NPH_DATASET_BASE + filename,
        "name": name,
        "description": description,
        "modality": "CT",
        "default_box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
        "filename": filename,
    }


# Sample registry keyed by sample id.
SAMPLE_IMAGES = {
    "nph_1": _nph_sample(
        "normal-pressure-hydrocephalus-36.png",
        "NPH Case 1 - Coronal",
        "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
        (450, 350, 750, 700),
    ),
    "nph_2": _nph_sample(
        "normal-pressure-hydrocephalus-36-2.png",
        "NPH Case 2 - Coronal",
        "NPH showing ventricular enlargement and transependymal changes",
        (400, 300, 700, 650),
    ),
    "nph_3": _nph_sample(
        "normal-pressure-hydrocephalus-36-3.png",
        "NPH Case 3 - Axial",
        "Axial view showing enlarged lateral ventricles",
        (420, 380, 680, 620),
    ),
}
| # ============================================================================= | |
| # MODEL CONFIGURATION | |
| # ============================================================================= | |
@dataclass
class ModelConfig:
    """Model configuration with capabilities.

    Declared as a dataclass so the annotated fields below generate the
    keyword ``__init__`` that ``MODELS_CONFIG`` relies on; without the
    decorator ``ModelConfig(name=...)`` raises ``TypeError``.
    """
    # Display name used in dropdowns and the status table.
    name: str
    # Whether the model is offered by the server.
    enabled: bool
    # Long and short human-readable descriptions.
    description: str
    short_desc: str
    # Capability flags.
    supports_2d: bool = False
    supports_3d: bool = False
    supports_dwi: bool = False
    # True when the model expects a prompt (the comparison UI filters on this).
    needs_prompt: bool = True
    # Grouping label shown in the status table.
    category: str = "foundation"
def _env_flag(var_name: str, default: str) -> bool:
    """Return True when the env var (or ``default``) equals "true", ignoring case."""
    return os.getenv(var_name, default).lower() == "true"


# Registry of every model the server knows about, keyed by model id.
# Each model can be switched on or off with its ENABLE_* environment variable.
MODELS_CONFIG = {
    # Foundation Models
    "medsam2": ModelConfig(
        name="MedSAM2",
        enabled=_env_flag("ENABLE_MEDSAM2", "true"),
        description="3D volume segmentation with bi-directional propagation",
        short_desc="3D Bi-directional",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "mcp_medsam": ModelConfig(
        name="MCP-MedSAM",
        enabled=_env_flag("ENABLE_MCP_MEDSAM", "true"),
        description="Lightweight 2D with modality/content prompts (~5x faster)",
        short_desc="Fast 2D + Modality",
        supports_2d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "sam_med3d": ModelConfig(
        name="SAM-Med3D",
        enabled=_env_flag("ENABLE_SAM_MED3D", "false"),
        description="Native 3D SAM with 245+ classes and sliding window",
        short_desc="3D Multi-class (245+)",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "medsam_3d": ModelConfig(
        name="MedSAM-3D",
        enabled=_env_flag("ENABLE_MEDSAM_3D", "false"),
        description="3D MedSAM with self-sorting memory bank",
        short_desc="3D Memory Bank",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    # Specialized Models
    "tractseg": ModelConfig(
        name="TractSeg",
        enabled=_env_flag("ENABLE_TRACTSEG", "true"),
        description="White matter bundle segmentation from diffusion MRI (72 bundles)",
        short_desc="72 WM Bundles",
        supports_3d=True,
        supports_dwi=True,
        needs_prompt=False,
        category="specialized",
    ),
    "nnunet": ModelConfig(
        name="nnU-Net",
        enabled=_env_flag("ENABLE_NNUNET", "true"),
        description="Self-configuring U-Net for any biomedical dataset",
        short_desc="Auto-Configuring",
        supports_2d=True,
        supports_3d=True,
        needs_prompt=False,
        category="specialized",
    ),
}
# Modality label -> integer code; alternate spellings share a code.
MODALITY_MAP = dict(
    [
        ("CT", 0),
        ("MRI", 1),
        ("MR", 1),
        ("PET", 2),
        ("X-ray", 3),
        ("XRAY", 3),
    ]
)
| # ============================================================================= | |
| # SAMPLE DATA FUNCTIONS | |
| # ============================================================================= | |
def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
    """Load a sample image by ID, downloading if necessary.

    Args:
        sample_id: Key into ``SAMPLE_IMAGES``.

    Returns:
        ``(img_array, meta)`` where ``img_array`` is the image as a NumPy
        array (3-D inputs are converted to grayscale) and ``meta`` carries
        the sample's name, description, modality, default box and array
        shape. ``None`` when the id is unknown, the download fails, or the
        cached file cannot be read.
    """
    sample = SAMPLE_IMAGES.get(sample_id)
    if sample is None:
        return None
    img_path = SAMPLES_DIR / sample["filename"]

    # Download if not cached. The payload goes to a ".part" file first so an
    # interrupted transfer never leaves a truncated file that later calls
    # would mistake for a valid cache entry.
    if not img_path.exists():
        partial_path = img_path.with_name(img_path.name + ".part")
        try:
            logger.info(f"Downloading sample {sample_id} from {sample['url']}")
            urllib.request.urlretrieve(sample["url"], partial_path)
            partial_path.replace(img_path)
            logger.info(f"Sample downloaded to {img_path}")
        except Exception as e:
            logger.error(f"Failed to download sample {sample_id}: {e}")
            if partial_path.exists():
                partial_path.unlink()
            return None

    # Read the pixels inside a context manager so the file handle is released.
    try:
        with Image.open(img_path) as img:
            img_array = np.array(img)
    except OSError as e:
        logger.error(f"Failed to read sample {sample_id} at {img_path}: {e}")
        return None

    # Convert to grayscale
    if img_array.ndim == 3:
        img_array = np.array(Image.fromarray(img_array).convert('L'))

    meta = {
        "name": sample["name"],
        "description": sample["description"],
        "modality": sample["modality"],
        "default_box": sample["default_box"],
        "shape": img_array.shape
    }
    return img_array, meta
def get_sample_image_path(sample_id: str) -> Optional[Path]:
    """Get path to sample image, downloading if needed.

    Args:
        sample_id: Key into ``SAMPLE_IMAGES``.

    Returns:
        Path of the cached image inside ``SAMPLES_DIR``, or ``None`` when the
        id is unknown or the download fails.
    """
    sample = SAMPLE_IMAGES.get(sample_id)
    if sample is None:
        return None
    img_path = SAMPLES_DIR / sample["filename"]
    if img_path.exists():
        return img_path

    # Download to a ".part" file and rename on success, so a failed transfer
    # cannot leave a truncated file that is later treated as cached.
    partial_path = img_path.with_name(img_path.name + ".part")
    try:
        urllib.request.urlretrieve(sample["url"], partial_path)
        partial_path.replace(img_path)
    except Exception as e:
        logger.error(f"Failed to download sample: {e}")
        if partial_path.exists():
            partial_path.unlink()
        return None
    return img_path
| # ============================================================================= | |
| # UTILITY FUNCTIONS | |
| # ============================================================================= | |
def compress_mask(mask: np.ndarray) -> str:
    """Serialize ``mask`` in ``.npy`` format, gzip it, and return base64 text."""
    raw = io.BytesIO()
    gz_stream = gzip.GzipFile(fileobj=raw, mode="wb")
    try:
        np.save(gz_stream, mask)
    finally:
        # Closing flushes the gzip trailer into ``raw``.
        gz_stream.close()
    encoded = base64.b64encode(raw.getvalue())
    return encoded.decode("ascii")
def decompress_mask(mask_b64: str) -> np.ndarray:
    """Inverse of the mask compression: base64 -> gzip -> ``.npy`` -> array."""
    compressed = base64.b64decode(mask_b64)
    npy_bytes = gzip.decompress(compressed)
    return np.load(io.BytesIO(npy_bytes))
def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as PNG and return the bytes as base64 ASCII text."""
    png_buffer = io.BytesIO()
    img.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return str(base64.b64encode(png_bytes), "ascii")
def base64_to_image(b64: str) -> Image.Image:
    """Decode base64 text and open the resulting bytes as a PIL image."""
    image_bytes = base64.b64decode(b64)
    return Image.open(io.BytesIO(image_bytes))
def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
    """Create segmentation overlay.

    Args:
        image: 2-D grayscale or 3-D multi-channel array. Non-uint8 input is
            min-max rescaled to 0..255.
        mask: Array indexed like the image's first two axes; non-zero
            entries mark the pixels to tint.
        color: Per-channel tint blended into the masked pixels.
        alpha: Blend weight of ``color`` (0 keeps the image, 1 paints solid).

    Returns:
        The blended overlay as a PIL image.
    """
    if image.dtype != np.uint8:
        # Rescale in float64: integer dtypes could overflow in
        # ``image - image.min()`` and bool arrays do not support subtraction.
        img_f = image.astype(np.float64)
        lowest = img_f.min()
        span = img_f.max() - lowest
        if span > 0:
            img_norm = ((img_f - lowest) / span * 255).astype(np.uint8)
        else:
            # Constant (or non-finite) image: avoid dividing by zero, which
            # would produce NaNs and an undefined uint8 cast.
            img_norm = np.zeros(image.shape, dtype=np.uint8)
    else:
        img_norm = image

    # Grayscale input is replicated to three channels so it can be tinted.
    if img_norm.ndim == 2:
        img_rgb = np.stack([img_norm] * 3, axis=-1)
    else:
        img_rgb = img_norm

    overlay = img_rgb.copy()
    mask_bool = mask.astype(bool)
    for i, c in enumerate(color):
        overlay[mask_bool, i] = (1 - alpha) * overlay[mask_bool, i] + alpha * c
    return Image.fromarray(overlay.astype(np.uint8))
| # ============================================================================= | |
| # MODEL INFERENCE FUNCTIONS | |
| # ============================================================================= | |
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
    """Run MedSAM2 on 3D volume.

    Mock implementation: ``volume_bytes`` is not inspected. The result is a
    fixed-size 64x256x256 mask containing an ellipsoid centred in-plane and,
    along the slice axis, at the box's ``slice_idx`` (middle slice when
    absent).

    Args:
        volume_bytes: Raw bytes of the uploaded volume (currently unused).
        box_json: JSON object; only its optional ``slice_idx`` key is read.

    Returns:
        Dict with the uint8 ``mask``, its compressed form ``mask_b64``, the
        ``shape`` as a list, and ``method``.
    """
    box = json.loads(box_json)
    logger.info(f"MedSAM2: Processing 3D volume with box {box}")

    # Mock implementation
    D, H, W = 64, 256, 256
    cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2

    # Vectorized ellipsoid test (semi-axes 10 x 40 x 40 voxels). It is
    # equivalent to visiting every voxel in Python loops, but evaluated by
    # NumPy on broadcast index grids instead of ~4.2M interpreter iterations.
    zz, yy, xx = np.ogrid[:D, :H, :W]
    inside = ((zz - cz) / 10) ** 2 + ((yy - cy) / 40) ** 2 + ((xx - cx) / 40) ** 2 <= 1
    mask = inside.astype(np.uint8)

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "medsam2"
    }
def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dict:
    """Run MCP-MedSAM on 2D image.

    Mock implementation: fills the prompt box and dilates it by two
    iterations.

    Args:
        image: Array whose first two axes are height and width.
        box: Mapping with ``x1``, ``y1``, ``x2``, ``y2`` pixel coordinates.
        modality: Imaging modality label, echoed back in the result.

    Returns:
        Dict with the uint8 ``mask``, ``mask_b64``, ``shape``, ``method``
        and ``modality``.

    Raises:
        KeyError: If a box coordinate is missing.
        ValueError, TypeError: If a coordinate cannot be converted to int.
    """
    logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
    H, W = image.shape[:2]
    mask = np.zeros((H, W), dtype=np.uint8)

    # Order the corners and clamp them to the image. Raw negative values
    # would be interpreted as "from the end" by NumPy slicing, and reversed
    # corners would silently yield an empty mask.
    left, right = sorted((int(box["x1"]), int(box["x2"])))
    top, bottom = sorted((int(box["y1"]), int(box["y2"])))
    x1, x2 = max(0, min(left, W)), max(0, min(right, W))
    y1, y2 = max(0, min(top, H)), max(0, min(bottom, H))
    mask[y1:y2, x1:x2] = 1

    # Imported here because ``import scipy`` alone does not guarantee that
    # the ``ndimage`` submodule is loaded.
    from scipy import ndimage
    mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "mcp_medsam",
        "modality": modality
    }
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
    """Run SAM-Med3D.

    Placeholder: returns random labels in ``[0, 5)`` over the first three
    axes of ``volume``; ``points`` is only logged and ``labels`` is unused.
    """
    logger.info(f"SAM-Med3D: Processing with points {points}")
    spatial_shape = volume.shape[:3]
    mask = np.random.randint(0, 5, size=spatial_shape, dtype=np.uint8)
    payload = {"mask": mask}
    payload["mask_b64"] = compress_mask(mask)
    payload["shape"] = list(mask.shape)
    payload["method"] = "sam_med3d"
    return payload
def run_medsam_3d(volume: np.ndarray, box: Dict) -> Dict:
    """Run MedSAM-3D.

    Placeholder: returns a random binary mask over the first three axes of
    ``volume``; ``box`` is only logged.
    """
    logger.info(f"MedSAM-3D: Processing with box {box}")
    spatial_shape = volume.shape[:3]
    # Threshold uniform noise, then store it as 0/1 bytes once.
    mask_u8 = (np.random.rand(*spatial_shape) > 0.5).astype(np.uint8)
    return dict(
        mask=mask_u8,
        mask_b64=compress_mask(mask_u8),
        shape=list(spatial_shape),
        method="medsam_3d",
    )
def run_tractseg(volume: np.ndarray) -> Dict:
    """Run TractSeg.

    Placeholder: returns random binary masks for 72 bundles, stacked on a
    trailing axis after the first three axes of ``volume``.
    """
    logger.info("TractSeg: Processing DWI")
    noise = np.random.rand(*volume.shape[:3], 72)
    bundles_u8 = (noise > 0.5).astype(np.uint8)
    return dict(
        bundles=bundles_u8,
        mask_b64=compress_mask(bundles_u8),
        shape=list(bundles_u8.shape),
        method="tractseg",
        num_bundles=72,
    )
def run_nnunet(volume: np.ndarray, task: str = "Task001_BrainTumour") -> Dict:
    """Run nnU-Net.

    Placeholder: returns random labels in ``[0, 4)``. A 3-D input yields a
    label map of the same shape; any other rank yields one over the first
    two axes.
    """
    logger.info(f"nnU-Net: Processing task {task}")
    out_shape = volume.shape if volume.ndim == 3 else volume.shape[:2]
    seg = np.random.randint(0, 4, size=out_shape, dtype=np.uint8)
    return dict(
        segmentation=seg,
        mask_b64=compress_mask(seg),
        shape=list(seg.shape),
        method="nnunet",
        task=task,
    )
| # ============================================================================= | |
| # API ENDPOINTS | |
| # ============================================================================= | |
def api_health():
    """Health check endpoint.

    Reports the enabled model ids, the compute device, the server version
    and the ids of the configured sample images.
    """
    active_models = []
    for model_id, cfg in MODELS_CONFIG.items():
        if cfg.enabled:
            active_models.append(model_id)

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    return dict(
        status="healthy",
        models=active_models,
        device=device,
        version="2.0.0",
        samples_available=list(SAMPLE_IMAGES.keys()),
    )
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
    """Gradio-compatible endpoint for HydroMorph app.

    Mock segmentation: draws a centred ellipse whose semi-axes are a quarter
    of the image height and width, overlays it in green and saves the
    result as a PNG under ``TEMP_DIR``.

    Args:
        image_file: A dict with a ``path`` entry, an object exposing
            ``.name``, or anything convertible to a path string.
        prompt: Target structure; only used in the log and status message.
        modality: Imaging modality; only used in the log and status message.
        window_type: Accepted for interface compatibility; not used.

    Returns:
        ``(result_path, status_message)``; ``result_path`` is ``None`` on
        failure and the message then starts with ``"Error: "``.
    """
    try:
        logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
        if image_file is None:
            return None, "Error: No image provided"

        # Load image
        if isinstance(image_file, dict):
            file_path = image_file.get("path")
        elif hasattr(image_file, 'name'):
            file_path = image_file.name
        else:
            file_path = str(image_file)
        if not file_path:
            # e.g. a dict payload without "path": report it as a missing
            # image rather than failing inside Image.open(None).
            return None, "Error: No image provided"

        with Image.open(file_path) as img:
            image = np.array(img.convert('L'))

        # Mock segmentation: vectorized form of the per-pixel ellipse test.
        H, W = image.shape
        cy, cx = H // 2, W // 2
        yy, xx = np.ogrid[:H, :W]
        inside = ((yy - cy) / (H / 4)) ** 2 + ((xx - cx) / (W / 4)) ** 2 <= 1
        mask = inside.astype(np.uint8)

        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))

        # mkstemp reserves a unique file name, so requests arriving within
        # the same second cannot overwrite each other's result.
        fd, temp_path = tempfile.mkstemp(prefix="result_", suffix=".png", dir=TEMP_DIR)
        os.close(fd)
        overlay.save(temp_path)
        return str(temp_path), f"Segmented {prompt} using {modality}"
    except Exception as e:
        logger.exception("process_with_status failed")
        return None, f"Error: {str(e)}"
def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
    """Direct JSON API for 2D segmentation.

    Parses the box, loads the image as grayscale, runs the requested model
    (only ``mcp_medsam`` is handled) and returns the compressed mask plus a
    base64 PNG overlay. Every failure is logged and reported as
    ``{"error": message}``.
    """
    try:
        box = json.loads(box_json)

        # Accept either a file-like object exposing ``.name`` or a plain path.
        source = image_file.name if hasattr(image_file, 'name') else image_file
        image = np.array(Image.open(source).convert('L'))

        if model != "mcp_medsam":
            return {"error": f"Model {model} not supported for 2D"}
        result = run_mcp_medsam_2d(image, box, modality)

        overlay_b64 = image_to_base64(overlay_mask_on_image(image, result["mask"]))
        response = {"success": True}
        response["mask_b64"] = result["mask_b64"]
        response["overlay_b64"] = overlay_b64
        response["shape"] = result["shape"]
        response["method"] = result["method"]
        response["modality"] = result.get("modality", modality)
        return response
    except Exception as e:
        logger.exception("2D segmentation failed")
        return {"error": str(e)}
def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
    """Direct JSON API for 3D segmentation.

    Reads the uploaded volume as raw bytes and runs the requested model
    (only ``medsam2`` is handled). Every failure is logged and reported as
    ``{"error": message}``.
    """
    try:
        # Accept either a file-like object exposing ``.name`` or a plain path.
        file_path = volume_file.name if hasattr(volume_file, 'name') else volume_file
        with open(file_path, 'rb') as handle:
            volume_bytes = handle.read()

        if model != "medsam2":
            return {"error": f"Model {model} not supported for 3D"}
        result = run_medsam2_3d(volume_bytes, box_json)

        response = {"success": True}
        for key in ("mask_b64", "shape", "method"):
            response[key] = result[key]
        return response
    except Exception as e:
        logger.exception("3D segmentation failed")
        return {"error": str(e)}
def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
    """Compare multiple models on the same image.

    Every supported model id currently runs the same 2D mock
    (``run_mcp_medsam_2d``); the ids differ only in overlay colour. Ids
    without a comparison path are skipped (and logged).

    Args:
        image_file: Path or object exposing ``.name``.
        box_json: JSON bounding box passed to the model.
        models_json: JSON list of model ids.
        modality: Imaging modality label.

    Returns:
        ``{"success": True, "results": {model_id: {...}}}`` where each entry
        holds either the mask/overlay/timing or ``{"success": False,
        "error": ...}``; ``{"error": message}`` when the request as a whole
        cannot be processed.
    """
    try:
        models = json.loads(models_json)
        box = json.loads(box_json)

        source = image_file.name if hasattr(image_file, 'name') else image_file
        image = np.array(Image.open(source).convert('L'))

        # Overlay colour per model id; this also defines which ids are
        # supported by the comparison.
        colors = {
            "mcp_medsam": (0, 255, 0),
            "medsam2": (255, 0, 0),
            "sam_med3d": (0, 0, 255),
            "medsam_3d": (255, 255, 0),
            "nnunet": (255, 0, 255)
        }

        results = {}
        for model in models:
            if not isinstance(model, str) or model not in colors:
                logger.warning(f"Model {model} has no 2D comparison path; skipping")
                continue

            start = time.time()
            try:
                result = run_mcp_medsam_2d(image, box, modality)
                overlay = overlay_mask_on_image(image, result["mask"], color=colors[model])
                results[model] = {
                    "success": True,
                    "mask_b64": result["mask_b64"],
                    "overlay_b64": image_to_base64(overlay),
                    "inference_time": round(time.time() - start, 2),
                    "shape": result["shape"]
                }
            except Exception as e:
                # Keep going with the other models, but leave a traceback.
                logger.exception(f"Comparison failed for model {model}")
                results[model] = {"success": False, "error": str(e)}

        return {"success": True, "results": results}
    except Exception as e:
        logger.exception("Model comparison failed")
        return {"error": str(e)}
| # ============================================================================= | |
| # GRADIO INTERFACE | |
| # ============================================================================= | |
def create_interface():
    """Create Gradio interface with sample images and all models.

    Builds one ``gr.Blocks`` app with six tabs: sample browser, single-model
    run, model comparison, 3D segmentation, the mobile-app endpoint
    (``api_name="process_with_status"``) and a status page
    (``api_name="health"``).

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(
        title="NeuroSeg Server - HydroMorph Backend",
        theme=gr.themes.Soft(),
        css="""
        .sample-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
        .model-checkbox { margin: 5px 0; }
        """
    ) as demo:
        gr.Markdown("""
        # π§ NeuroSeg Server
        Backend API for HydroMorph React Native app (iOS, Android, Web).
        **MCP Server**: `https://mmrech-medsam2-server.hf.space/gradio_api/mcp/sse`
        **Models**: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
        """)

        # --- Try with Sample CT ---
        with gr.Tab("π Try with Sample CT"):
            gr.Markdown("Select a sample CT scan to test the models:")
            sample_radio = gr.Radio(
                choices=[(f"{v['name']}: {v['description'][:50]}...", k) for k, v in SAMPLE_IMAGES.items()],
                value="nph_1",
                label="Select Sample"
            )
            with gr.Row():
                sample_preview = gr.Image(label="Selected Sample", type="pil")
                sample_info = gr.JSON(label="Sample Info")

            def load_sample_preview(sample_id):
                """Return (PIL image, metadata) for a sample, or (None, {})."""
                result = load_sample_image(sample_id)
                if result is None:
                    return None, {}
                img_array, meta = result
                return Image.fromarray(img_array), meta

            sample_radio.change(
                fn=load_sample_preview,
                inputs=[sample_radio],
                outputs=[sample_preview, sample_info]
            )
            # Load initial sample
            demo.load(
                fn=lambda: load_sample_preview("nph_1"),
                outputs=[sample_preview, sample_info]
            )
            gr.Markdown("### Use this sample in:")
            # NOTE(review): these two buttons have no click handlers yet.
            with gr.Row():
                use_in_single = gr.Button("π― Single Model", variant="secondary")
                use_in_compare = gr.Button("π¬ Model Comparison", variant="secondary")

        # --- Single Model ---
        with gr.Tab("π― Single Model"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Input source
                    input_source = gr.Radio(
                        choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
                        value="sample",
                        label="Input Source"
                    )
                    single_sample = gr.Dropdown(
                        choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
                        value="nph_1",
                        label="Sample",
                        visible=True
                    )
                    single_upload = gr.Image(
                        label="Upload",
                        type="filepath",
                        visible=False
                    )

                    def toggle_input(source):
                        """Show the sample picker or the uploader, never both."""
                        return {
                            single_sample: gr.update(visible=source == "sample"),
                            single_upload: gr.update(visible=source == "upload")
                        }

                    input_source.change(fn=toggle_input, inputs=[input_source], outputs=[single_sample, single_upload])
                    # Model selection
                    enabled_models = [(v.name, k) for k, v in MODELS_CONFIG.items() if v.enabled]
                    single_model = gr.Dropdown(
                        choices=enabled_models,
                        value=enabled_models[0][1] if enabled_models else None,
                        label="Model"
                    )
                    # Dynamic inputs based on model
                    single_box = gr.Textbox(
                        label="Bounding Box (JSON)",
                        value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"]),
                        visible=True
                    )
                    single_modality = gr.Dropdown(
                        label="Modality",
                        choices=list(MODALITY_MAP.keys()),
                        value="CT",
                        visible=True
                    )
                    # NOTE(review): this checkbox is not connected to any handler.
                    single_prompt_only = gr.Checkbox(
                        label="Show only prompt-based models",
                        value=False
                    )
                    single_run = gr.Button("π Run Model", variant="primary")
                with gr.Column(scale=2):
                    single_output = gr.JSON(label="Result")
                    single_overlay = gr.Image(label="Segmentation Overlay")

            def run_single(source, sample, upload, model, box, modality):
                """Run one model on the chosen image; return (result JSON, overlay)."""
                if source == "sample":
                    img_path = get_sample_image_path(sample)
                else:
                    img_path = upload
                if img_path is None:
                    return {"error": "No image provided"}, None
                result = api_segment_2d(img_path, box, model, modality)
                if "error" in result:
                    return result, None
                # Generate overlay
                img = Image.open(img_path).convert('L')
                mask = decompress_mask(result["mask_b64"])
                overlay = overlay_mask_on_image(np.array(img), mask)
                return result, overlay

            single_run.click(
                fn=run_single,
                inputs=[input_source, single_sample, single_upload, single_model, single_box, single_modality],
                outputs=[single_output, single_overlay]
            )

        # --- Model Comparison ---
        with gr.Tab("π¬ Model Comparison"):
            with gr.Row():
                with gr.Column(scale=1):
                    comp_input_source = gr.Radio(
                        choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
                        value="sample",
                        label="Input Source"
                    )
                    comp_sample = gr.Dropdown(
                        choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
                        value="nph_1",
                        label="Sample",
                        visible=True
                    )
                    comp_upload = gr.Image(
                        label="Upload",
                        type="filepath",
                        visible=False
                    )
                    comp_input_source.change(
                        fn=lambda x: {comp_sample: gr.update(visible=x == "sample"), comp_upload: gr.update(visible=x == "upload")},
                        inputs=[comp_input_source],
                        outputs=[comp_sample, comp_upload]
                    )
                    comp_box = gr.Textbox(
                        label="Bounding Box (JSON)",
                        value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"])
                    )
                    comp_modality = gr.Dropdown(
                        label="Modality",
                        choices=list(MODALITY_MAP.keys()),
                        value="CT"
                    )
                    # Model selection with categories
                    gr.Markdown("### Select Models to Compare")
                    comp_prompt_only = gr.Checkbox(
                        label="Prompt-based models only",
                        value=False,
                        info="Filter to models that accept prompts"
                    )
                    # Foundation models
                    gr.Markdown("**Foundation Models**")
                    comp_medsam2 = gr.Checkbox(label="MedSAM2 (3D Bi-directional)", value=True)
                    comp_mcp = gr.Checkbox(label="MCP-MedSAM (Fast 2D)", value=True)
                    comp_sam3d = gr.Checkbox(label="SAM-Med3D (245+ classes)", value=False)
                    comp_medsam3d = gr.Checkbox(label="MedSAM-3D (Memory Bank)", value=False)
                    # Specialized models
                    gr.Markdown("**Specialized Models**")
                    comp_tractseg = gr.Checkbox(label="TractSeg (72 bundles)", value=False)
                    comp_nnunet = gr.Checkbox(label="nnU-Net (Auto-configuring)", value=False)
                    comp_run = gr.Button("π Run Comparison", variant="primary")
                with gr.Column(scale=2):
                    comp_output = gr.JSON(label="Comparison Results")
                    comp_gallery = gr.Gallery(label="Model Overlays")

            def run_comparison(source, sample, upload, box, modality, prompt_only, *model_flags):
                """Run the ticked models; return (result JSON, gallery items).

                ``model_flags`` must arrive in the same order as
                ``model_names`` below (it mirrors the checkbox order passed
                to ``comp_run.click``).
                """
                models = []
                model_names = ["medsam2", "mcp_medsam", "sam_med3d", "medsam_3d", "tractseg", "nnunet"]
                for name, enabled in zip(model_names, model_flags):
                    if enabled:
                        # Skip non-prompt models if prompt_only is checked
                        if prompt_only and not MODELS_CONFIG[name].needs_prompt:
                            continue
                        models.append(name)
                if not models:
                    return {"error": "No models selected"}, []
                if source == "sample":
                    img_path = get_sample_image_path(sample)
                else:
                    img_path = upload
                if img_path is None:
                    return {"error": "No image provided"}, []
                result = api_compare_models(img_path, box, json.dumps(models), modality)
                if "error" in result:
                    return result, []
                # Extract gallery images
                gallery = []
                for model, data in result.get("results", {}).items():
                    if data.get("success") and "overlay_b64" in data:
                        img = base64_to_image(data["overlay_b64"])
                        gallery.append((img, f"{model} ({data.get('inference_time', 0)}s)"))
                return result, gallery

            comp_run.click(
                fn=run_comparison,
                inputs=[
                    comp_input_source, comp_sample, comp_upload,
                    comp_box, comp_modality, comp_prompt_only,
                    comp_medsam2, comp_mcp, comp_sam3d, comp_medsam3d,
                    comp_tractseg, comp_nnunet
                ],
                outputs=[comp_output, comp_gallery]
            )

        # --- 3D Segmentation ---
        with gr.Tab("π₯ 3D Segmentation"):
            with gr.Row():
                with gr.Column():
                    seg3d_volume = gr.File(label="Volume (.npy or .nii.gz)", file_types=[".npy", ".nii.gz"])
                    seg3d_box = gr.Textbox(
                        label="Box + Slice (JSON)",
                        value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200, "slice_idx": 32}'
                    )
                    # Prefer "medsam2" as the default, but only when it is
                    # actually offered: a default outside ``choices`` is
                    # invalid when that model is disabled via environment.
                    models_3d = [(v.name, k) for k, v in MODELS_CONFIG.items() if v.supports_3d and v.enabled]
                    model_3d_keys = [k for _, k in models_3d]
                    if "medsam2" in model_3d_keys:
                        default_3d = "medsam2"
                    else:
                        default_3d = model_3d_keys[0] if model_3d_keys else None
                    seg3d_model = gr.Dropdown(
                        label="Model",
                        choices=models_3d,
                        value=default_3d
                    )
                    seg3d_run = gr.Button("Segment", variant="primary")
                with gr.Column():
                    seg3d_output = gr.JSON(label="Result")
            seg3d_run.click(
                fn=api_segment_3d,
                inputs=[seg3d_volume, seg3d_box, seg3d_model],
                outputs=seg3d_output
            )

        # --- Mobile App Endpoint ---
        with gr.Tab("π± Mobile App"):
            gr.Markdown("""
            This endpoint is used by the HydroMorph mobile app.
            **Endpoint:** `POST /gradio_api/call/process_with_status`
            """)
            with gr.Row():
                with gr.Column():
                    mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
                    mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
                    mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
                    mobile_window = gr.Dropdown(
                        label="Window",
                        choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
                        value="Brain (Grey Matter)"
                    )
                    mobile_btn = gr.Button("Run Segmentation", variant="primary")
                with gr.Column():
                    mobile_result_img = gr.Image(label="Result")
                    mobile_status = gr.Textbox(label="Status")
            mobile_btn.click(
                fn=process_with_status,
                inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
                outputs=[mobile_result_img, mobile_status],
                api_name="process_with_status"
            )

        # --- Status ---
        with gr.Tab("βοΈ Status"):
            health_btn = gr.Button("Check Health")
            health_output = gr.JSON(label="System Status")
            # Model status table
            gr.Markdown("### Model Status")
            model_status_data = [
                [v.name, "β Enabled" if v.enabled else "β Disabled", v.category, "Yes" if v.needs_prompt else "No"]
                for k, v in MODELS_CONFIG.items()
            ]
            gr.Dataframe(
                headers=["Model", "Status", "Category", "Needs Prompt"],
                value=model_status_data
            )
            health_btn.click(fn=api_health, outputs=health_output, api_name="health")
    return demo
| # ============================================================================= | |
| # MAIN | |
| # ============================================================================= | |
if __name__ == "__main__":
    # Log what is about to be served.
    enabled = [model_id for model_id, cfg in MODELS_CONFIG.items() if cfg.enabled]
    logger.info(f"Starting NeuroSeg Server with {len(enabled)} models: {enabled}")
    logger.info(f"Samples configured: {list(SAMPLE_IMAGES.keys())}")

    demo = create_interface()

    # Launch with MCP server support
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
        quiet=False,
        mcp_server=True,  # Enable MCP server
    )
    demo.launch(**launch_options)