# medsam2-server / app.py
# Author: mmrech
# Commit af2b1cc (unverified): "Add complete UI with sample images and MCP server"
"""
NeuroSeg Server — HydroMorph Backend API
=========================================
Backend API for HydroMorph React Native app (iOS, Android, Web).
ENDPOINTS FOR MOBILE APP:
- POST /gradio_api/upload — Upload PNG slice
- POST /gradio_api/call/{endpoint} — Call segmentation endpoint
- GET /gradio_api/call/{endpoint}/{event_id} — SSE poll for result
- GET /gradio_api/file={path} — Download result image
- POST /api/segment_2d — Direct JSON API (no Gradio protocol)
- POST /api/segment_3d — Direct JSON API for 3D volumes
- GET /api/health — Health check
MCP SERVER:
- All models exposed as MCP tools at /gradio_api/mcp/sse
MODELS SUPPORTED:
- MedSAM2: 3D volume with bi-directional propagation
- MCP-MedSAM: Fast 2D with modality/content prompts
- SAM-Med3D: Native 3D (245+ classes, sliding window)
- MedSAM-3D: 3D with memory bank
- TractSeg: White matter bundles (72 tracts)
- nnU-Net: Self-configuring U-Net
Author: Matheus Machado Rech
"""
import gzip
import io
import json
import logging
import os
import tempfile
import base64
import time
import urllib.request
from typing import Optional, Tuple, List, Dict, Any
from dataclasses import dataclass, field
from pathlib import Path
from functools import wraps
import gradio as gr
import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
import nibabel as nib
import scipy
# Logging: a single root configuration for the whole server; module code logs via `logger`.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("neuroseg_server")
# Paths: every runtime directory sits next to this script and is created at import time.
SCRIPT_DIR = Path(__file__).parent.resolve()
CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
TEMP_DIR = SCRIPT_DIR / "temp"
SAMPLES_DIR = SCRIPT_DIR / "samples"
for _runtime_dir in (CHECKPOINT_DIR, TEMP_DIR, SAMPLES_DIR):
    _runtime_dir.mkdir(exist_ok=True)
del _runtime_dir
# =============================================================================
# SAMPLE DATA CONFIGURATION
# =============================================================================
# All bundled samples are CT PNGs hosted in the same Hugging Face dataset repo.
_NPH_DATASET_URL = "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/"


def _nph_sample(filename, name, description, box):
    """Build one SAMPLE_IMAGES entry for a file in the NPH dataset repo.

    ``box`` is an ``(x1, y1, x2, y2)`` tuple in pixel coordinates and becomes
    the entry's ``default_box`` prompt.
    """
    x1, y1, x2, y2 = box
    return {
        "url": _NPH_DATASET_URL + filename,
        "name": name,
        "description": description,
        "modality": "CT",
        "default_box": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
        "filename": filename,
    }


SAMPLE_IMAGES = {
    "nph_1": _nph_sample(
        "normal-pressure-hydrocephalus-36.png",
        "NPH Case 1 - Coronal",
        "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
        (450, 350, 750, 700),
    ),
    "nph_2": _nph_sample(
        "normal-pressure-hydrocephalus-36-2.png",
        "NPH Case 2 - Coronal",
        "NPH showing ventricular enlargement and transependymal changes",
        (400, 300, 700, 650),
    ),
    "nph_3": _nph_sample(
        "normal-pressure-hydrocephalus-36-3.png",
        "NPH Case 3 - Axial",
        "Axial view showing enlarged lateral ventricles",
        (420, 380, 680, 620),
    ),
}
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
@dataclass
class ModelConfig:
    """Static description of one segmentation backend and its capabilities.

    Entries live in MODELS_CONFIG; the UI and the health endpoint read the
    ``enabled``/capability flags to decide which models to list.
    """

    name: str  # display name shown in the UI
    enabled: bool  # whether this deployment has the model switched on
    description: str  # long, human-readable description
    short_desc: str  # compact label
    supports_2d: bool = False  # can segment single 2D slices
    supports_3d: bool = False  # can segment 3D volumes
    supports_dwi: bool = False  # consumes diffusion-weighted input
    needs_prompt: bool = True  # takes a user prompt (drives the prompt-only filters)
    category: str = "foundation"  # grouping label, e.g. "foundation" / "specialized"
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from the environment.

    An unset variable yields *default*. A set value is stripped and compared
    case-insensitively: "true", "1", "yes" and "on" enable the flag; anything
    else (including an empty string) disables it. Previously only the exact
    string "true" (any case) enabled a model, so ``ENABLE_X=1`` silently
    turned it off.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"true", "1", "yes", "on"}


MODELS_CONFIG = {
    # Foundation Models
    "medsam2": ModelConfig(
        name="MedSAM2",
        enabled=_env_flag("ENABLE_MEDSAM2", True),
        description="3D volume segmentation with bi-directional propagation",
        short_desc="3D Bi-directional",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "mcp_medsam": ModelConfig(
        name="MCP-MedSAM",
        enabled=_env_flag("ENABLE_MCP_MEDSAM", True),
        description="Lightweight 2D with modality/content prompts (~5x faster)",
        short_desc="Fast 2D + Modality",
        supports_2d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "sam_med3d": ModelConfig(
        name="SAM-Med3D",
        enabled=_env_flag("ENABLE_SAM_MED3D", False),
        description="Native 3D SAM with 245+ classes and sliding window",
        short_desc="3D Multi-class (245+)",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "medsam_3d": ModelConfig(
        name="MedSAM-3D",
        enabled=_env_flag("ENABLE_MEDSAM_3D", False),
        description="3D MedSAM with self-sorting memory bank",
        short_desc="3D Memory Bank",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    # Specialized Models
    "tractseg": ModelConfig(
        name="TractSeg",
        enabled=_env_flag("ENABLE_TRACTSEG", True),
        description="White matter bundle segmentation from diffusion MRI (72 bundles)",
        short_desc="72 WM Bundles",
        supports_3d=True,
        supports_dwi=True,
        needs_prompt=False,
        category="specialized"
    ),
    "nnunet": ModelConfig(
        name="nnU-Net",
        enabled=_env_flag("ENABLE_NNUNET", True),
        description="Self-configuring U-Net for any biomedical dataset",
        short_desc="Auto-Configuring",
        supports_2d=True,
        supports_3d=True,
        needs_prompt=False,
        category="specialized"
    ),
}
# Modality label -> integer id ("MR" and "XRAY" are aliases of "MRI" and "X-ray").
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3, "XRAY": 3}
# =============================================================================
# SAMPLE DATA FUNCTIONS
# =============================================================================
def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
    """Load a sample image by ID as a 2D array plus display metadata.

    The file is fetched/cached through ``get_sample_image_path``. Palette
    ("P"), bilevel ("1") and multi-band (RGB, RGBA, LA, CMYK, ...) images are
    converted to 8-bit grayscale ("L"); other single-band modes (e.g. "L",
    "I;16") keep their native pixel values.

    Args:
        sample_id: Key into SAMPLE_IMAGES.

    Returns:
        ``(image_array, meta)`` where ``meta`` holds name, description,
        modality, default_box and shape; or None if the id is unknown, the
        download failed, or the cached file could not be decoded (the bad file
        is then removed so the next call downloads it again).
    """
    sample = SAMPLE_IMAGES.get(sample_id)
    if sample is None:
        return None
    img_path = get_sample_image_path(sample_id)
    if img_path is None:
        return None
    try:
        with Image.open(img_path) as img:
            # np.array() on a "P" image yields palette indices, not intensities,
            # so anything that is not plain single-band data is converted.
            if img.mode in ("P", "1") or len(img.getbands()) > 1:
                img_array = np.array(img.convert("L"))
            else:
                img_array = np.array(img)
    except OSError as e:  # includes PIL.UnidentifiedImageError and truncated files
        logger.error(f"Failed to decode sample {sample_id} at {img_path}: {e}")
        try:
            Path(img_path).unlink()
        except OSError:
            pass
        return None
    meta = {
        "name": sample["name"],
        "description": sample["description"],
        "modality": sample["modality"],
        "default_box": sample["default_box"],
        "shape": img_array.shape
    }
    return img_array, meta
def get_sample_image_path(sample_id: str, timeout: float = 30.0) -> Optional[Path]:
    """Return the local path of a sample image, downloading it on first use.

    The file is fetched with a network timeout and written to a uniquely
    named temporary file in SAMPLES_DIR. It is then atomically renamed into
    place, so a failed or interrupted download never leaves a truncated file
    that later calls would mistake for a cached copy.

    Args:
        sample_id: Key into SAMPLE_IMAGES.
        timeout: Seconds to wait on the HTTP connection/reads before giving up.

    Returns:
        Path to the cached image, or None if the id is unknown or the download failed.
    """
    sample = SAMPLE_IMAGES.get(sample_id)
    if sample is None:
        return None
    img_path = SAMPLES_DIR / sample["filename"]
    if img_path.exists():
        return img_path
    tmp_name = None
    try:
        logger.info(f"Downloading sample {sample_id} from {sample['url']}")
        # read() with no size argument raises http.client.IncompleteRead when
        # the body is shorter than the advertised Content-Length.
        with urllib.request.urlopen(sample["url"], timeout=timeout) as response:
            payload = response.read()
        if not payload:
            raise ValueError("empty response body")
        fd, tmp_name = tempfile.mkstemp(dir=SAMPLES_DIR, prefix=".download-", suffix=".part")
        with os.fdopen(fd, "wb") as out:
            out.write(payload)
        os.replace(tmp_name, img_path)
        tmp_name = None  # renamed into place; nothing left to clean up
        logger.info(f"Sample downloaded to {img_path}")
    except Exception as e:
        logger.error(f"Failed to download sample {sample_id}: {e}")
        return None
    finally:
        if tmp_name is not None:
            Path(tmp_name).unlink(missing_ok=True)
    return img_path
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def compress_mask(mask: np.ndarray) -> str:
    """Serialise *mask* to ``.npy`` bytes, gzip them, and return base64 ASCII text.

    ``decompress_mask`` is the inverse operation.
    """
    packed = io.BytesIO()
    with gzip.GzipFile(fileobj=packed, mode="wb") as gz_stream:
        np.save(gz_stream, mask)
    encoded = base64.b64encode(packed.getvalue())
    return encoded.decode("ascii")
def decompress_mask(mask_b64: str) -> np.ndarray:
    """Decode the base64 text produced by ``compress_mask`` back into an array."""
    compressed = base64.b64decode(mask_b64)
    with gzip.GzipFile(fileobj=io.BytesIO(compressed), mode="rb") as gz_stream:
        restored = np.load(gz_stream)
    return restored
def image_to_base64(img: Image.Image) -> str:
    """Encode *img* as PNG and return the bytes as base64 ASCII text."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("ascii")
def base64_to_image(b64: str) -> Image.Image:
    """Decode base64 image bytes (e.g. from ``image_to_base64``) into a PIL Image.

    The in-memory buffer stays referenced by the returned image, so its lazy
    pixel loading keeps working after this function returns.
    """
    raw_bytes = base64.b64decode(b64)
    return Image.open(io.BytesIO(raw_bytes))
def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
    """Create a segmentation overlay by tinting masked pixels with *color*.

    Args:
        image: 2D grayscale or H x W x C array. Non-uint8 input is min-max
            scaled to 0..255; a constant image becomes all zeros instead of
            producing NaNs from a division by zero.
        mask: Array with the same height/width as *image*; every truthy
            (non-zero) pixel is tinted.
        color: Tint values; channel ``i`` of the result is blended with ``color[i]``.
        alpha: Weight of *color* in the blend (0 keeps the image, 1 paints solid color).

    Returns:
        The blended result as a PIL image.

    Raises:
        ValueError: If *mask* does not match the image's height and width.
    """
    if image.dtype != np.uint8:
        # Scale in float64 so integer inputs (e.g. int16 CT) cannot overflow in the subtraction.
        as_float = image.astype(np.float64)
        lo, hi = as_float.min(), as_float.max()
        if hi > lo:
            img_norm = ((as_float - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            img_norm = np.zeros(image.shape, dtype=np.uint8)
    else:
        img_norm = image
    img_rgb = np.stack([img_norm] * 3, axis=-1) if img_norm.ndim == 2 else img_norm
    mask_bool = mask.astype(bool)
    if mask_bool.shape != img_rgb.shape[:2]:
        raise ValueError(
            f"mask shape {mask_bool.shape} does not match image height/width {img_rgb.shape[:2]}"
        )
    overlay = img_rgb.copy()
    for channel, value in enumerate(color):
        blended = (1 - alpha) * overlay[mask_bool, channel] + alpha * value
        # Clip so alpha outside [0, 1] cannot wrap around in the uint8 assignment.
        overlay[mask_bool, channel] = np.clip(blended, 0, 255)
    return Image.fromarray(overlay.astype(np.uint8))
# =============================================================================
# MODEL INFERENCE FUNCTIONS
# =============================================================================
@spaces.GPU(duration=120)
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
    """Run MedSAM2 on a 3D volume (current implementation is a mock).

    The mock ignores *volume_bytes*. It returns a 64x256x256 uint8 mask with
    an ellipsoid (semi-axes 10 slices, 40 px, 40 px). The ellipsoid is centred
    on the in-plane centre and on ``box["slice_idx"]`` (default: middle slice).

    Args:
        volume_bytes: Raw bytes of the uploaded volume (unused by the mock).
        box_json: JSON object with the prompt box and an optional "slice_idx".

    Returns:
        Dict with the "mask" array, its compressed "mask_b64", "shape" and "method".
    """
    box = json.loads(box_json)
    logger.info(f"MedSAM2: Processing 3D volume with box {box}")
    # Mock implementation
    D, H, W = 64, 256, 256
    cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
    # Open grids broadcast to (D, H, W): same ellipsoid test as a per-voxel loop,
    # but evaluated in numpy instead of ~4.2M Python iterations.
    zz, yy, xx = np.ogrid[:D, :H, :W]
    inside = ((zz - cz) / 10) ** 2 + ((yy - cy) / 40) ** 2 + ((xx - cx) / 40) ** 2 <= 1
    mask = inside.astype(np.uint8)
    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "medsam2"
    }
@spaces.GPU(duration=60)
def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dict:
    """Run MCP-MedSAM on a 2D image (current implementation is a placeholder).

    The placeholder fills the prompt box, clipped to the image, and dilates
    it twice with scipy's default binary structuring element.

    Args:
        image: Array whose first two axes are (height, width).
        box: Mapping with "x1", "y1", "x2", "y2" pixel coordinates. Numbers or
            numeric strings are accepted (fractions are truncated), and corners
            may be given in either order.
        modality: Modality label; logged and echoed back in the result.

    Returns:
        Dict with the uint8 "mask", its compressed "mask_b64", "shape",
        "method" and "modality".
    """
    logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
    H, W = image.shape[:2]
    mask = np.zeros((H, W), dtype=np.uint8)
    x1, y1, x2, y2 = (int(float(box[key])) for key in ("x1", "y1", "x2", "y2"))
    # Normalise reversed boxes and clip to the image. Without clipping, a
    # negative start index would be read by Python slicing as an offset from
    # the end of the axis.
    left, right = max(0, min(x1, x2)), min(W, max(x1, x2))
    top, bottom = max(0, min(y1, y2)), min(H, max(y1, y2))
    if left < right and top < bottom:
        mask[top:bottom, left:right] = 1
    from scipy import ndimage
    mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "mcp_medsam",
        "modality": modality
    }
@spaces.GPU(duration=90)
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
    """Run SAM-Med3D on a volume (placeholder).

    Returns random uint8 labels in [0, 5) over the first three axes of
    *volume*. *points* is only logged and *labels* is not used yet.
    """
    logger.info(f"SAM-Med3D: Processing with points {points}")
    spatial_shape = volume.shape[:3]
    label_map = np.random.randint(0, 5, size=spatial_shape, dtype=np.uint8)
    return dict(
        mask=label_map,
        mask_b64=compress_mask(label_map),
        shape=list(label_map.shape),
        method="sam_med3d",
    )
@spaces.GPU(duration=120)
def run_medsam_3d(volume: np.ndarray, box: Dict) -> Dict:
    """Run MedSAM-3D on a volume (placeholder).

    Returns a random binary uint8 mask over the first three axes of *volume*.
    *box* is only logged.
    """
    logger.info(f"MedSAM-3D: Processing with box {box}")
    spatial_shape = volume.shape[:3]
    binary = (np.random.rand(*spatial_shape) > 0.5).astype(np.uint8)
    return dict(
        mask=binary,
        mask_b64=compress_mask(binary),
        shape=list(spatial_shape),
        method="medsam_3d",
    )
@spaces.GPU(duration=180)
def run_tractseg(volume: np.ndarray) -> Dict:
    """Run TractSeg on a DWI volume (placeholder).

    Returns 72 random binary uint8 bundle masks stacked on a trailing axis
    after the first three axes of *volume*.
    """
    logger.info("TractSeg: Processing DWI")
    n_bundles = 72
    bundle_masks = (np.random.rand(*volume.shape[:3], n_bundles) > 0.5).astype(np.uint8)
    return dict(
        bundles=bundle_masks,
        mask_b64=compress_mask(bundle_masks),
        shape=list(bundle_masks.shape),
        method="tractseg",
        num_bundles=n_bundles,
    )
@spaces.GPU(duration=120)
def run_nnunet(volume: np.ndarray, task: str = "Task001_BrainTumour") -> Dict:
    """Run nnU-Net for *task* (placeholder).

    Returns random uint8 labels in [0, 4). The label map has the full shape of
    a 3D *volume*, or the first two axes for inputs of any other rank.
    """
    logger.info(f"nnU-Net: Processing task {task}")
    out_shape = volume.shape if volume.ndim == 3 else volume.shape[:2]
    seg = np.random.randint(0, 4, size=out_shape, dtype=np.uint8)
    return dict(
        segmentation=seg,
        mask_b64=compress_mask(seg),
        shape=list(seg.shape),
        method="nnunet",
        task=task,
    )
# =============================================================================
# API ENDPOINTS
# =============================================================================
def api_health():
    """Report liveness, enabled model keys, compute device, version and sample ids."""
    enabled_keys = [key for key, cfg in MODELS_CONFIG.items() if cfg.enabled]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return dict(
        status="healthy",
        models=enabled_keys,
        device=device,
        version="2.0.0",
        samples_available=list(SAMPLE_IMAGES),
    )
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
    """Gradio-compatible endpoint used by the HydroMorph mobile app.

    Loads the slice as 8-bit grayscale and produces a placeholder
    segmentation: a centred ellipse whose semi-axes are a quarter of the
    image height/width. It renders a green overlay and saves it as a uniquely
    named PNG in TEMP_DIR, so concurrent requests never overwrite each other's
    result.

    Args:
        image_file: A path (str or os.PathLike), a Gradio file dict with a
            "path" key, or an object exposing its path as ``.name``.
        prompt: Target label; only echoed in the status message.
        modality: Modality label; only echoed in the status message.
        window_type: Display window name; currently unused.

    Returns:
        ``(overlay_png_path, status_message)`` on success, or
        ``(None, "Error: ...")`` on failure.
    """
    try:
        logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
        if image_file is None:
            return None, "Error: No image provided"
        # Path-likes first: pathlib.Path also has ``.name``, but it is only the basename.
        if isinstance(image_file, (str, os.PathLike)):
            file_path = image_file
        elif isinstance(image_file, dict):
            file_path = image_file.get("path")
        elif hasattr(image_file, 'name'):
            file_path = image_file.name
        else:
            file_path = str(image_file)
        if not file_path:
            return None, "Error: No image provided"
        with Image.open(file_path) as img:
            image = np.array(img.convert('L'))
        # Mock segmentation: vectorised ellipse test over the whole grid.
        H, W = image.shape
        cy, cx = H // 2, W // 2
        yy, xx = np.ogrid[:H, :W]
        inside = ((yy - cy) / (H / 4)) ** 2 + ((xx - cx) / (W / 4)) ** 2 <= 1
        mask = inside.astype(np.uint8)
        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
        fd, temp_path = tempfile.mkstemp(prefix="result_", suffix=".png", dir=TEMP_DIR)
        with os.fdopen(fd, "wb") as out:
            overlay.save(out, format="PNG")
        return str(temp_path), f"Segmented {prompt} using {modality}"
    except Exception as e:
        logger.exception("process_with_status failed")
        return None, f"Error: {str(e)}"
def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
    """Direct JSON API for 2D segmentation.

    Args:
        image_file: A path (str or os.PathLike), a Gradio file dict with a
            "path" key, an object exposing its path as ``.name`` (e.g. a
            tempfile wrapper), or a binary file-like object.
        box_json: JSON object with "x1", "y1", "x2", "y2".
        model: Model key; only "mcp_medsam" is supported for 2D.
        modality: Modality label forwarded to the model.

    Returns:
        On success ``{"success": True, "mask_b64", "overlay_b64", "shape",
        "method", "modality"}``; on failure ``{"success": False, "error": msg}``.
    """
    try:
        if image_file is None:
            return {"success": False, "error": "No image provided"}
        box = json.loads(box_json)
        # Path-likes must be checked before ``.name``: pathlib.Path has a
        # ``.name`` attribute too, but it is only the basename, which would be
        # resolved against the current working directory.
        if isinstance(image_file, (str, os.PathLike)):
            source = image_file
        elif isinstance(image_file, dict):
            source = image_file.get("path")
        elif hasattr(image_file, 'name'):
            source = image_file.name
        else:
            source = image_file
        if source is None:
            return {"success": False, "error": "No image provided"}
        if model != "mcp_medsam":
            return {"success": False, "error": f"Model {model} not supported for 2D"}
        with Image.open(source) as img:
            image = np.array(img.convert('L'))
        result = run_mcp_medsam_2d(image, box, modality)
        overlay = overlay_mask_on_image(image, result["mask"])
        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "overlay_b64": image_to_base64(overlay),
            "shape": result["shape"],
            "method": result["method"],
            "modality": result.get("modality", modality)
        }
    except Exception as e:
        logger.exception("2D segmentation failed")
        return {"success": False, "error": str(e)}
def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
    """Direct JSON API for 3D segmentation.

    Args:
        volume_file: A path (str or os.PathLike), a Gradio file dict with a
            "path" key, or an object exposing its path as ``.name``.
        box_json: JSON prompt (box plus optional "slice_idx") passed to the model.
        model: Model key; only "medsam2" is supported. The model is checked
            before the file is read.

    Returns:
        On success ``{"success": True, "mask_b64", "shape", "method"}``; on
        failure ``{"success": False, "error": msg}``.
    """
    try:
        if volume_file is None:
            return {"success": False, "error": "No volume provided"}
        # Path-likes first: pathlib.Path.name is only the basename, not a full path.
        if isinstance(volume_file, (str, os.PathLike)):
            file_path = volume_file
        elif isinstance(volume_file, dict):
            file_path = volume_file.get("path")
        elif hasattr(volume_file, 'name'):
            file_path = volume_file.name
        else:
            file_path = volume_file
        if file_path is None:
            return {"success": False, "error": "No volume provided"}
        if model != "medsam2":
            return {"success": False, "error": f"Model {model} not supported for 3D"}
        with open(file_path, 'rb') as f:
            volume_bytes = f.read()
        result = run_medsam2_3d(volume_bytes, box_json)
        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "shape": result["shape"],
            "method": result["method"]
        }
    except Exception as e:
        logger.exception("3D segmentation failed")
        return {"success": False, "error": str(e)}
def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
    """Compare multiple models on the same 2D image.

    NOTE: only MCP-MedSAM has a 2D runner in this module. "medsam2",
    "sam_med3d", "medsam_3d" and "nnunet" are currently served by
    ``run_mcp_medsam_2d`` as a stand-in, so each result carries a "method"
    field naming the runner that actually produced it. Any other model name
    (e.g. "tractseg") is reported as unsupported rather than silently dropped.

    Args:
        image_file: A path (str or os.PathLike), a Gradio file dict with a
            "path" key, an object exposing its path as ``.name``, or a binary
            file-like object.
        box_json: JSON object with "x1", "y1", "x2", "y2".
        models_json: JSON list of model keys (a single JSON string is also accepted).
        modality: Modality label forwarded to the model.

    Returns:
        ``{"success": True, "results": {model: {...}}}``. A model whose run fails
        gets ``{"success": False, "error": msg}``. If the request itself is
        invalid, the whole response is ``{"success": False, "error": msg}``.
    """
    # Overlay colour per comparable model; membership doubles as the support check.
    colors = {
        "mcp_medsam": (0, 255, 0),
        "medsam2": (255, 0, 0),
        "sam_med3d": (0, 0, 255),
        "medsam_3d": (255, 255, 0),
        "nnunet": (255, 0, 255)
    }
    try:
        models = json.loads(models_json)
        if isinstance(models, str):
            models = [models]
        box = json.loads(box_json)
        # Path-likes first: pathlib.Path.name is only the basename, not a full path.
        if isinstance(image_file, (str, os.PathLike)):
            source = image_file
        elif isinstance(image_file, dict):
            source = image_file.get("path")
        elif hasattr(image_file, 'name'):
            source = image_file.name
        else:
            source = image_file
        if source is None:
            return {"success": False, "error": "No image provided"}
        with Image.open(source) as img:
            image = np.array(img.convert('L'))
        results = {}
        for model in models:
            color = colors.get(model) if isinstance(model, str) else None
            if color is None:
                results[model if isinstance(model, str) else str(model)] = {
                    "success": False,
                    "error": f"Model {model} not supported for 2D comparison",
                }
                continue
            start = time.perf_counter()
            try:
                result = run_mcp_medsam_2d(image, box, modality)
                overlay = overlay_mask_on_image(image, result["mask"], color=color)
                results[model] = {
                    "success": True,
                    "mask_b64": result["mask_b64"],
                    "overlay_b64": image_to_base64(overlay),
                    "inference_time": round(time.perf_counter() - start, 2),
                    "shape": result["shape"],
                    "method": result["method"],
                }
            except Exception as e:
                logger.exception(f"Model comparison run failed for {model}")
                results[model] = {"success": False, "error": str(e)}
        return {"success": True, "results": results}
    except Exception as e:
        logger.exception("Model comparison failed")
        return {"success": False, "error": str(e)}
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
def create_interface():
    """Create the Gradio interface with sample images and all models.

    Tabs: sample browser, single-model 2D segmentation, multi-model
    comparison, 3D volume segmentation, the HydroMorph mobile endpoint and a
    status page. The tabs are wrapped in an explicit ``gr.Tabs`` so that the
    "Use this sample in" buttons can copy the chosen sample (and its default
    box) into the target tab and switch to it.

    Returns:
        gr.Blocks: the assembled interface; the caller is responsible for launching it.
    """
    default_sample = "nph_1"
    default_box = json.dumps(SAMPLE_IMAGES[default_sample]["default_box"])
    sample_choices = [(v["name"], k) for k, v in SAMPLE_IMAGES.items()]

    def shorten(text, limit=50):
        """Cut *text* to *limit* characters, appending "..." only if something was removed."""
        return text if len(text) <= limit else f"{text[:limit]}..."

    def box_for_sample(sample_id):
        """Return the default-box JSON for *sample_id*, or a no-op update for an unknown id."""
        sample = SAMPLE_IMAGES.get(sample_id)
        return json.dumps(sample["default_box"]) if sample else gr.update()

    def preferred_2d_model(choices):
        """Return the first 2D-capable key among ``(label, key)`` *choices*, else the first key, else None."""
        keys = [key for _, key in choices]
        return next((key for key in keys if MODELS_CONFIG[key].supports_2d), keys[0] if keys else None)

    def single_model_choices(prompt_only=False):
        """Return enabled models for the Single Model tab, optionally restricted to prompt-based ones."""
        return [
            (cfg.name, key)
            for key, cfg in MODELS_CONFIG.items()
            if cfg.enabled and (cfg.needs_prompt or not prompt_only)
        ]

    def selected_image_path(source, sample, upload):
        """Return a path string for the chosen input (cached sample or upload), or None."""
        path = get_sample_image_path(sample) if source == "sample" else upload
        # Convert to str: the API helpers also accept file-wrapper objects via
        # their ``.name``, and a pathlib.Path's ``.name`` is only the basename.
        return None if path is None else str(path)

    with gr.Blocks(
        title="NeuroSeg Server - HydroMorph Backend",
        theme=gr.themes.Soft(),
        css="""
        .sample-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
        .model-checkbox { margin: 5px 0; }
        """
    ) as demo:
        gr.Markdown("""
        # 🧠 NeuroSeg Server

        Backend API for HydroMorph React Native app (iOS, Android, Web).

        **MCP Server**: `https://mmrech-medsam2-server.hf.space/gradio_api/mcp/sse`

        **Models**: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
        """)
        with gr.Tabs() as tabs:
            # --- Try with Sample CT ---
            with gr.Tab("📋 Try with Sample CT", id="samples"):
                gr.Markdown("Select a sample CT scan to test the models:")
                sample_radio = gr.Radio(
                    choices=[(f"{v['name']}: {shorten(v['description'])}", k) for k, v in SAMPLE_IMAGES.items()],
                    value=default_sample,
                    label="Select Sample"
                )
                with gr.Row():
                    sample_preview = gr.Image(label="Selected Sample", type="pil")
                    sample_info = gr.JSON(label="Sample Info")

                def load_sample_preview(sample_id):
                    result = load_sample_image(sample_id)
                    if result is None:
                        return None, {}
                    img_array, meta = result
                    return Image.fromarray(img_array), meta

                sample_radio.change(
                    fn=load_sample_preview,
                    inputs=[sample_radio],
                    outputs=[sample_preview, sample_info]
                )
                # Populate the preview when the page loads.
                demo.load(
                    fn=lambda: load_sample_preview(default_sample),
                    outputs=[sample_preview, sample_info]
                )
                gr.Markdown("### Use this sample in:")
                with gr.Row():
                    use_in_single = gr.Button("🎯 Single Model", variant="secondary")
                    use_in_compare = gr.Button("🔬 Model Comparison", variant="secondary")

            # --- Single Model ---
            with gr.Tab("🎯 Single Model", id="single"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_source = gr.Radio(
                            choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
                            value="sample",
                            label="Input Source"
                        )
                        single_sample = gr.Dropdown(
                            choices=sample_choices,
                            value=default_sample,
                            label="Sample",
                            visible=True
                        )
                        single_upload = gr.Image(
                            label="Upload",
                            type="filepath",
                            visible=False
                        )

                        def toggle_input(source):
                            return gr.update(visible=source == "sample"), gr.update(visible=source == "upload")

                        input_source.change(fn=toggle_input, inputs=[input_source], outputs=[single_sample, single_upload])
                        # This tab runs 2D segmentation, so default to a 2D-capable model.
                        single_choices = single_model_choices()
                        single_model = gr.Dropdown(
                            choices=single_choices,
                            value=preferred_2d_model(single_choices),
                            label="Model"
                        )
                        single_box = gr.Textbox(
                            label="Bounding Box (JSON)",
                            value=default_box,
                            visible=True
                        )
                        single_modality = gr.Dropdown(
                            label="Modality",
                            choices=list(MODALITY_MAP.keys()),
                            value="CT",
                            visible=True
                        )
                        single_prompt_only = gr.Checkbox(
                            label="Show only prompt-based models",
                            value=False
                        )
                        single_run = gr.Button("🚀 Run Model", variant="primary")
                    with gr.Column(scale=2):
                        single_output = gr.JSON(label="Result")
                        single_overlay = gr.Image(label="Segmentation Overlay")

                def filter_single_models(prompt_only):
                    choices = single_model_choices(prompt_only)
                    return gr.update(choices=choices, value=preferred_2d_model(choices))

                single_prompt_only.change(
                    fn=filter_single_models,
                    inputs=[single_prompt_only],
                    outputs=[single_model],
                    api_name=False
                )
                # Keep the box prompt in sync with the chosen sample's default ROI.
                single_sample.change(fn=box_for_sample, inputs=[single_sample], outputs=[single_box], api_name=False)

                def run_single(source, sample, upload, model, box, modality):
                    img_path = selected_image_path(source, sample, upload)
                    if img_path is None:
                        return {"error": "No image provided"}, None
                    result = api_segment_2d(img_path, box, model, modality)
                    if "error" in result:
                        return result, None
                    # api_segment_2d already rendered the overlay; decode it instead of recomputing.
                    return result, base64_to_image(result["overlay_b64"])

                single_run.click(
                    fn=run_single,
                    inputs=[input_source, single_sample, single_upload, single_model, single_box, single_modality],
                    outputs=[single_output, single_overlay]
                )

            # --- Model Comparison ---
            with gr.Tab("🔬 Model Comparison", id="compare"):
                with gr.Row():
                    with gr.Column(scale=1):
                        comp_input_source = gr.Radio(
                            choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
                            value="sample",
                            label="Input Source"
                        )
                        comp_sample = gr.Dropdown(
                            choices=sample_choices,
                            value=default_sample,
                            label="Sample",
                            visible=True
                        )
                        comp_upload = gr.Image(
                            label="Upload",
                            type="filepath",
                            visible=False
                        )
                        comp_input_source.change(
                            fn=lambda x: (gr.update(visible=x == "sample"), gr.update(visible=x == "upload")),
                            inputs=[comp_input_source],
                            outputs=[comp_sample, comp_upload]
                        )
                        comp_box = gr.Textbox(
                            label="Bounding Box (JSON)",
                            value=default_box
                        )
                        comp_modality = gr.Dropdown(
                            label="Modality",
                            choices=list(MODALITY_MAP.keys()),
                            value="CT"
                        )
                        # Model selection with categories
                        gr.Markdown("### Select Models to Compare")
                        comp_prompt_only = gr.Checkbox(
                            label="Prompt-based models only",
                            value=False,
                            info="Filter to models that accept prompts"
                        )
                        gr.Markdown("**Foundation Models**")
                        comp_medsam2 = gr.Checkbox(label="MedSAM2 (3D Bi-directional)", value=True)
                        comp_mcp = gr.Checkbox(label="MCP-MedSAM (Fast 2D)", value=True)
                        comp_sam3d = gr.Checkbox(label="SAM-Med3D (245+ classes)", value=False)
                        comp_medsam3d = gr.Checkbox(label="MedSAM-3D (Memory Bank)", value=False)
                        gr.Markdown("**Specialized Models**")
                        comp_tractseg = gr.Checkbox(label="TractSeg (72 bundles)", value=False)
                        comp_nnunet = gr.Checkbox(label="nnU-Net (Auto-configuring)", value=False)
                        comp_run = gr.Button("🚀 Run Comparison", variant="primary")
                    with gr.Column(scale=2):
                        comp_output = gr.JSON(label="Comparison Results")
                        comp_gallery = gr.Gallery(label="Model Overlays")

                comp_sample.change(fn=box_for_sample, inputs=[comp_sample], outputs=[comp_box], api_name=False)

                def run_comparison(source, sample, upload, box, modality, prompt_only, *model_flags):
                    # Order must match the checkbox inputs wired below.
                    model_names = ["medsam2", "mcp_medsam", "sam_med3d", "medsam_3d", "tractseg", "nnunet"]
                    models = []
                    for name, enabled in zip(model_names, model_flags):
                        if not enabled:
                            continue
                        # Skip non-prompt models if prompt_only is checked
                        if prompt_only and not MODELS_CONFIG[name].needs_prompt:
                            continue
                        models.append(name)
                    if not models:
                        return {"error": "No models selected"}, []
                    img_path = selected_image_path(source, sample, upload)
                    if img_path is None:
                        return {"error": "No image provided"}, []
                    result = api_compare_models(img_path, box, json.dumps(models), modality)
                    if "error" in result:
                        return result, []
                    gallery = []
                    for model, data in result.get("results", {}).items():
                        if data.get("success") and "overlay_b64" in data:
                            img = base64_to_image(data["overlay_b64"])
                            gallery.append((img, f"{model} ({data.get('inference_time', 0)}s)"))
                    return result, gallery

                comp_run.click(
                    fn=run_comparison,
                    inputs=[
                        comp_input_source, comp_sample, comp_upload,
                        comp_box, comp_modality, comp_prompt_only,
                        comp_medsam2, comp_mcp, comp_sam3d, comp_medsam3d,
                        comp_tractseg, comp_nnunet
                    ],
                    outputs=[comp_output, comp_gallery]
                )

            # --- 3D Segmentation ---
            with gr.Tab("🏥 3D Segmentation", id="seg3d"):
                with gr.Row():
                    with gr.Column():
                        seg3d_volume = gr.File(label="Volume (.npy or .nii.gz)", file_types=[".npy", ".nii.gz"])
                        seg3d_box = gr.Textbox(
                            label="Box + Slice (JSON)",
                            value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200, "slice_idx": 32}'
                        )
                        seg3d_choices = [(v.name, k) for k, v in MODELS_CONFIG.items() if v.supports_3d and v.enabled]
                        seg3d_keys = [k for _, k in seg3d_choices]
                        seg3d_model = gr.Dropdown(
                            label="Model",
                            choices=seg3d_choices,
                            # MedSAM2 can be disabled via env; never default to a value outside the choices.
                            value="medsam2" if "medsam2" in seg3d_keys else (seg3d_keys[0] if seg3d_keys else None)
                        )
                        seg3d_run = gr.Button("Segment", variant="primary")
                    with gr.Column():
                        seg3d_output = gr.JSON(label="Result")
                seg3d_run.click(
                    fn=api_segment_3d,
                    inputs=[seg3d_volume, seg3d_box, seg3d_model],
                    outputs=seg3d_output
                )

            # --- Mobile App Endpoint ---
            with gr.Tab("📱 Mobile App", id="mobile"):
                gr.Markdown("""
                This endpoint is used by the HydroMorph mobile app.

                **Endpoint:** `POST /gradio_api/call/process_with_status`
                """)
                with gr.Row():
                    with gr.Column():
                        mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
                        mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
                        mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
                        mobile_window = gr.Dropdown(
                            label="Window",
                            choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
                            value="Brain (Grey Matter)"
                        )
                        mobile_btn = gr.Button("Run Segmentation", variant="primary")
                    with gr.Column():
                        mobile_result_img = gr.Image(label="Result")
                        mobile_status = gr.Textbox(label="Status")
                mobile_btn.click(
                    fn=process_with_status,
                    inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
                    outputs=[mobile_result_img, mobile_status],
                    api_name="process_with_status"
                )

            # --- Status ---
            with gr.Tab("⚙️ Status", id="status"):
                health_btn = gr.Button("Check Health")
                health_output = gr.JSON(label="System Status")
                gr.Markdown("### Model Status")
                model_status_data = [
                    [v.name, "✅ Enabled" if v.enabled else "❌ Disabled", v.category, "Yes" if v.needs_prompt else "No"]
                    for v in MODELS_CONFIG.values()
                ]
                gr.Dataframe(
                    headers=["Model", "Status", "Category", "Needs Prompt"],
                    value=model_status_data
                )
                health_btn.click(fn=api_health, outputs=health_output, api_name="health")

        # "Use this sample in" buttons: the targets live in other tabs, so they are wired after all tabs exist.
        def send_sample(sample_id, tab_id):
            """Select *sample_id* as the sample input of the target tab, then switch to *tab_id*."""
            known = sample_id in SAMPLE_IMAGES
            return (
                gr.update(value="sample"),
                gr.update(value=sample_id, visible=True) if known else gr.update(visible=True),
                gr.update(visible=False),
                box_for_sample(sample_id),
                gr.update(selected=tab_id),
            )

        use_in_single.click(
            fn=lambda sample_id: send_sample(sample_id, "single"),
            inputs=[sample_radio],
            outputs=[input_source, single_sample, single_upload, single_box, tabs],
            api_name=False
        )
        use_in_compare.click(
            fn=lambda sample_id: send_sample(sample_id, "compare"),
            inputs=[sample_radio],
            outputs=[comp_input_source, comp_sample, comp_upload, comp_box, tabs],
            api_name=False
        )
    return demo
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    enabled_keys = [key for key, cfg in MODELS_CONFIG.items() if cfg.enabled]
    logger.info(f"Starting NeuroSeg Server with {len(enabled_keys)} models: {enabled_keys}")
    logger.info(f"Samples configured: {list(SAMPLE_IMAGES.keys())}")
    # Keep the module-level name `demo`: Gradio's auto-reload mode looks it up by default.
    demo = create_interface()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_api": True,
        "quiet": False,
        "mcp_server": True,  # Enable MCP server
    }
    demo.launch(**launch_options)