"""
NeuroSeg Server — HydroMorph Backend API
=========================================

Backend API for the HydroMorph React Native app (iOS, Android, Web).

ENDPOINTS FOR MOBILE APP:
    - POST /gradio_api/upload                      — Upload PNG slice
    - POST /gradio_api/call/{endpoint}             — Call segmentation endpoint
    - GET  /gradio_api/call/{endpoint}/{event_id}  — SSE poll for result
    - GET  /gradio_api/file={path}                 — Download result image
    - POST /api/segment_2d                         — Direct JSON API (no Gradio protocol)
    - POST /api/segment_3d                         — Direct JSON API for 3D volumes
    - GET  /api/health                             — Health check

MCP SERVER:
    - All models exposed as MCP tools at /gradio_api/mcp/sse

MODELS SUPPORTED:
    - MedSAM2:    3D volume with bi-directional propagation
    - MCP-MedSAM: Fast 2D with modality/content prompts
    - SAM-Med3D:  Native 3D (245+ classes, sliding window)
    - MedSAM-3D:  3D with memory bank
    - TractSeg:   White matter bundles (72 tracts)
    - nnU-Net:    Self-configuring U-Net

NOTE(review): the inference functions in this module currently produce
synthetic placeholder masks (geometric shapes / random labels) instead of real
model output, and only ``process_with_status`` and ``health`` are registered
with an explicit ``api_name`` in the Gradio UI — confirm how the ``/api/*``
routes listed above are exposed before relying on them.

Author: Matheus Machado Rech
"""

import base64
import gzip
import io
import json
import logging
import os
import tempfile
import time
import urllib.request
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Third-party imports keep their original relative order. In particular
# `spaces` stays ahead of `torch` (HF ZeroGPU appears to expect to patch torch
# before CUDA is initialised — confirm before reordering).
import gradio as gr
import spaces
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw
from huggingface_hub import hf_hub_download
import nibabel as nib
import scipy

# Module-wide logging: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("neuroseg_server")

# Working directories, anchored next to this script and created on import.
SCRIPT_DIR = Path(__file__).parent.resolve()
CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
TEMP_DIR = SCRIPT_DIR / "temp"
SAMPLES_DIR = SCRIPT_DIR / "samples"
for _workdir in (CHECKPOINT_DIR, TEMP_DIR, SAMPLES_DIR):
    _workdir.mkdir(exist_ok=True)
del _workdir
# =============================================================================
# SAMPLE DATA CONFIGURATION
# =============================================================================

# Bundled demo slices. Each entry holds the download URL, display metadata,
# a default prompt box in pixel coordinates (x1, y1, x2, y2) and the file name
# used for the local cache in SAMPLES_DIR.
SAMPLE_IMAGES = {
    "nph_1": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36.png",
        "name": "NPH Case 1 - Coronal",
        "description": "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
        "modality": "CT",
        "default_box": {"x1": 450, "y1": 350, "x2": 750, "y2": 700},
        "filename": "normal-pressure-hydrocephalus-36.png",
    },
    "nph_2": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-2.png",
        "name": "NPH Case 2 - Coronal",
        "description": "NPH showing ventricular enlargement and transependymal changes",
        "modality": "CT",
        "default_box": {"x1": 400, "y1": 300, "x2": 700, "y2": 650},
        "filename": "normal-pressure-hydrocephalus-36-2.png",
    },
    "nph_3": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-3.png",
        "name": "NPH Case 3 - Axial",
        "description": "Axial view showing enlarged lateral ventricles",
        "modality": "CT",
        "default_box": {"x1": 420, "y1": 380, "x2": 680, "y2": 620},
        "filename": "normal-pressure-hydrocephalus-36-3.png",
    },
}

# =============================================================================
# MODEL CONFIGURATION
# =============================================================================


@dataclass
class ModelConfig:
    """Static description of one segmentation model and its capabilities."""

    name: str  # display name shown in the UI
    enabled: bool  # controlled by an ENABLE_<MODEL> environment variable
    description: str  # long description
    short_desc: str  # short label
    supports_2d: bool = False  # accepts single 2D slices
    supports_3d: bool = False  # accepts 3D volumes
    supports_dwi: bool = False  # accepts diffusion-weighted input
    needs_prompt: bool = True  # prompt-driven model (used by the UI "prompt-only" filters)
    category: str = "foundation"  # "foundation" or "specialized"


def _env_enabled(var: str, default: str) -> bool:
    """Return True when env var *var* (falling back to *default*) equals "true", case-insensitively."""
    return os.getenv(var, default).lower() == "true"


MODELS_CONFIG = {
    # Foundation Models
    "medsam2": ModelConfig(
        name="MedSAM2",
        enabled=_env_enabled("ENABLE_MEDSAM2", "true"),
        description="3D volume segmentation with bi-directional propagation",
        short_desc="3D Bi-directional",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "mcp_medsam": ModelConfig(
        name="MCP-MedSAM",
        enabled=_env_enabled("ENABLE_MCP_MEDSAM", "true"),
        description="Lightweight 2D with modality/content prompts (~5x faster)",
        short_desc="Fast 2D + Modality",
        supports_2d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "sam_med3d": ModelConfig(
        name="SAM-Med3D",
        enabled=_env_enabled("ENABLE_SAM_MED3D", "false"),
        description="Native 3D SAM with 245+ classes and sliding window",
        short_desc="3D Multi-class (245+)",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    "medsam_3d": ModelConfig(
        name="MedSAM-3D",
        enabled=_env_enabled("ENABLE_MEDSAM_3D", "false"),
        description="3D MedSAM with self-sorting memory bank",
        short_desc="3D Memory Bank",
        supports_3d=True,
        needs_prompt=True,
        category="foundation",
    ),
    # Specialized Models
    "tractseg": ModelConfig(
        name="TractSeg",
        enabled=_env_enabled("ENABLE_TRACTSEG", "true"),
        description="White matter bundle segmentation from diffusion MRI (72 bundles)",
        short_desc="72 WM Bundles",
        supports_3d=True,
        supports_dwi=True,
        needs_prompt=False,
        category="specialized",
    ),
    "nnunet": ModelConfig(
        name="nnU-Net",
        enabled=_env_enabled("ENABLE_NNUNET", "true"),
        description="Self-configuring U-Net for any biomedical dataset",
        short_desc="Auto-Configuring",
        supports_2d=True,
        supports_3d=True,
        needs_prompt=False,
        category="specialized",
    ),
}

# Modality name -> integer code (aliases share a code).
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3, "XRAY": 3}

# =============================================================================
# SAMPLE DATA FUNCTIONS
# =============================================================================

# Seconds to wait on the sample-image server before giving up on a download.
_DOWNLOAD_TIMEOUT_S = 60

# PIL modes that np.array() already maps to a single-channel intensity image.
# Every other mode (RGB, RGBA, LA, palette "P", bilevel "1", ...) is converted
# to 8-bit grayscale "L" so palette indices / booleans are never mistaken for
# intensities.
_INTENSITY_MODES = frozenset({"L", "I", "I;16", "I;16B", "I;16L", "F"})


def _download_sample(sample_id: str) -> Optional[Path]:
    """Return the local path of sample *sample_id*, downloading it on first use.

    The file is streamed into a uniquely named temporary file inside
    SAMPLES_DIR and atomically renamed into place only once complete, so an
    interrupted or failed download never leaves a truncated file that later
    calls would mistake for a valid cached copy.

    Returns:
        The cached file path, or None if the id is unknown or the download
        failed (the failure is logged).
    """
    import shutil

    sample = SAMPLE_IMAGES.get(sample_id)
    if sample is None:
        return None
    img_path = SAMPLES_DIR / sample["filename"]
    if img_path.exists():
        return img_path

    tmp_name = None
    try:
        fd, tmp_name = tempfile.mkstemp(dir=SAMPLES_DIR, prefix=".download_", suffix=".part")
        logger.info(f"Downloading sample {sample_id} from {sample['url']}")
        # fdopen first so the descriptor is closed even if urlopen raises.
        with os.fdopen(fd, "wb") as out, urllib.request.urlopen(
            sample["url"], timeout=_DOWNLOAD_TIMEOUT_S
        ) as response:
            shutil.copyfileobj(response, out)
        os.replace(tmp_name, img_path)
        logger.info(f"Sample downloaded to {img_path}")
    except Exception as e:
        logger.error(f"Failed to download sample {sample_id}: {e}")
        if tmp_name is not None:
            Path(tmp_name).unlink(missing_ok=True)
        return None
    return img_path


def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
    """Load sample *sample_id* as a 2D intensity array plus display metadata.

    Colour, palette and bilevel images are converted to 8-bit grayscale;
    images that are already single-channel intensity images (8/16/32-bit
    integer or float) keep their original dtype.

    Returns:
        ``(image, meta)`` where *meta* holds name, description, modality,
        default_box and shape; or None if the sample is unknown, could not be
        downloaded, or could not be decoded (errors are logged).
    """
    img_path = _download_sample(sample_id)
    if img_path is None:
        return None

    try:
        # The context manager closes the file handle Image.open keeps lazily.
        with Image.open(img_path) as img:
            if img.mode not in _INTENSITY_MODES:
                img = img.convert("L")
            img_array = np.array(img)
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logger.error(f"Failed to read sample {sample_id} from {img_path}: {e}")
        return None

    sample = SAMPLE_IMAGES[sample_id]
    meta = {
        "name": sample["name"],
        "description": sample["description"],
        "modality": sample["modality"],
        "default_box": sample["default_box"],
        "shape": img_array.shape,
    }
    return img_array, meta


def get_sample_image_path(sample_id: str) -> Optional[Path]:
    """Return the cached path of sample *sample_id*, downloading it if needed.

    Returns None for unknown ids or failed downloads.
    """
    return _download_sample(sample_id)


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def compress_mask(mask: np.ndarray) -> str:
    """Serialise *mask* with ``np.save``, gzip it and return it as ASCII base64."""
    buf = io.BytesIO()
    with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
        np.save(gz, mask)
    return base64.b64encode(buf.getvalue()).decode("ascii")


def decompress_mask(mask_b64: str) -> np.ndarray:
    """Inverse of :func:`compress_mask`.

    ``allow_pickle=False`` keeps untrusted payloads from executing code via
    pickled object arrays.
    """
    buf = io.BytesIO(base64.b64decode(mask_b64))
    with gzip.GzipFile(fileobj=buf, mode="rb") as gz:
        return np.load(gz, allow_pickle=False)


def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as PNG and return it as ASCII base64."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


def base64_to_image(b64: str) -> Image.Image:
    """Decode a base64 string produced by :func:`image_to_base64` into a PIL image."""
    buf = io.BytesIO(base64.b64decode(b64))
    return Image.open(buf)


def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray,
                          color: Tuple[int, int, int] = (0, 255, 0),
                          alpha: float = 0.5) -> Image.Image:
    """Alpha-blend *color* over the pixels of *image* selected by *mask*.

    Args:
        image: 2D grayscale or HxWxC image. Non-uint8 input is min-max scaled
            to 0..255; a constant image maps to all zeros instead of dividing
            by zero.
        mask: Array matching ``image.shape[:2]``; any non-zero value marks a
            foreground pixel.
        color: RGB colour blended into the first three channels.
        alpha: Weight of *color* in the blend (0 keeps the image, 1 paints solid).

    Returns:
        The blended image as a PIL ``Image``.

    Raises:
        ValueError: if the mask shape does not match the image's spatial shape.
    """
    if image.dtype != np.uint8:
        as_float = image.astype(np.float64)
        lo, hi = float(as_float.min()), float(as_float.max())
        span = hi - lo
        if span > 0:
            img_norm = ((as_float - lo) / span * 255).astype(np.uint8)
        else:
            img_norm = np.zeros(image.shape, dtype=np.uint8)
    else:
        img_norm = image

    img_rgb = np.stack([img_norm] * 3, axis=-1) if img_norm.ndim == 2 else img_norm

    mask_bool = np.asarray(mask).astype(bool)
    if mask_bool.shape != img_rgb.shape[:2]:
        raise ValueError(
            f"mask shape {mask_bool.shape} does not match image shape {img_rgb.shape[:2]}"
        )

    overlay = img_rgb.copy()
    for channel, value in enumerate(color):
        overlay[mask_bool, channel] = (1 - alpha) * overlay[mask_bool, channel] + alpha * value
    return Image.fromarray(overlay.astype(np.uint8))


# =============================================================================
# MODEL INFERENCE FUNCTIONS
# =============================================================================

def _box_to_slices(box: Dict, height: int, width: int) -> Tuple[slice, slice]:
    """Turn a ``{"x1", "y1", "x2", "y2"}`` box into clipped (row, column) slices.

    Coordinates are cast to int, each pair is put in ascending order and
    clipped to the image bounds, so reversed corners still select the intended
    region and negative values cannot wrap around to the opposite edge.
    """
    x1, x2 = sorted((int(box["x1"]), int(box["x2"])))
    y1, y2 = sorted((int(box["y1"]), int(box["y2"])))
    x1, x2 = min(max(x1, 0), width), min(max(x2, 0), width)
    y1, y2 = min(max(y1, 0), height), min(max(y2, 0), height)
    return slice(y1, y2), slice(x1, x2)


@spaces.GPU(duration=120)
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
    """Run MedSAM2 on a 3D volume.

    NOTE: placeholder implementation — *volume_bytes* is not read. The result
    is a synthetic 64x256x256 ellipsoid (semi-axes 10x40x40 voxels) centred on
    ``box["slice_idx"]`` (default: middle slice) and the in-plane centre; the
    other box keys are only logged.

    Args:
        volume_bytes: Raw bytes of the uploaded volume file (currently unused).
        box_json: JSON object with the prompt box and optional ``slice_idx``.

    Returns:
        Dict with the uint8 ``mask``, its gzip+base64 form ``mask_b64``,
        ``shape`` and ``method``.
    """
    box = json.loads(box_json)
    logger.info(f"MedSAM2: Processing 3D volume with box {box}")

    D, H, W = 64, 256, 256
    cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
    # One broadcast ellipsoid test instead of a ~4M-iteration Python loop.
    z, y, x = np.ogrid[:D, :H, :W]
    inside = ((z - cz) / 10) ** 2 + ((y - cy) / 40) ** 2 + ((x - cx) / 40) ** 2 <= 1
    mask = inside.astype(np.uint8)

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "medsam2",
    }


@spaces.GPU(duration=60)
def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dict:
    """Run MCP-MedSAM on a 2D image.

    NOTE: placeholder implementation — the mask is the prompt box (ordered and
    clipped to the image) dilated by two iterations; *modality* is only logged
    and echoed back.

    Returns:
        Dict with the uint8 ``mask``, ``mask_b64``, ``shape``, ``method`` and
        ``modality``.

    Raises:
        KeyError: if the box lacks one of x1/y1/x2/y2.
    """
    from scipy import ndimage

    logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
    H, W = image.shape[:2]
    rows, cols = _box_to_slices(box, H, W)
    mask = np.zeros((H, W), dtype=np.uint8)
    mask[rows, cols] = 1
    mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "mcp_medsam",
        "modality": modality,
    }


@spaces.GPU(duration=90)
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
    """Run SAM-Med3D with point prompts.

    NOTE: placeholder — returns random labels 0..4 over ``volume.shape[:3]``;
    *points* are only logged and *labels* is unused.
    """
    logger.info(f"SAM-Med3D: Processing with points {points}")
    mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "sam_med3d",
    }


@spaces.GPU(duration=120)
def run_medsam_3d(volume: np.ndarray, box: Dict) -> Dict:
    """Run MedSAM-3D with a box prompt.

    NOTE: placeholder — returns a random binary mask over ``volume.shape[:3]``;
    *box* is only logged.
    """
    logger.info(f"MedSAM-3D: Processing with box {box}")
    mask = (np.random.rand(*volume.shape[:3]) > 0.5).astype(np.uint8)
    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "medsam_3d",
    }


@spaces.GPU(duration=180)
def run_tractseg(volume: np.ndarray) -> Dict:
    """Run TractSeg on a diffusion volume.

    NOTE: placeholder — returns 72 random binary bundle masks stacked on a
    trailing axis, i.e. shape ``volume.shape[:3] + (72,)``.
    """
    logger.info("TractSeg: Processing DWI")
    bundles = (np.random.rand(*volume.shape[:3], 72) > 0.5).astype(np.uint8)
    return {
        "bundles": bundles,
        "mask_b64": compress_mask(bundles),
        "shape": list(bundles.shape),
        "method": "tractseg",
        "num_bundles": 72,
    }


@spaces.GPU(duration=120)
def run_nnunet(volume: np.ndarray, task: str = "Task001_BrainTumour") -> Dict:
    """Run nnU-Net for *task*.

    NOTE: placeholder — returns random labels 0..3 with the full shape of a 3D
    input, or ``shape[:2]`` for any other rank.
    """
    logger.info(f"nnU-Net: Processing task {task}")
    size = volume.shape if volume.ndim == 3 else volume.shape[:2]
    seg = np.random.randint(0, 4, size=size, dtype=np.uint8)
    return {
        "segmentation": seg,
        "mask_b64": compress_mask(seg),
        "shape": list(seg.shape),
        "method": "nnunet",
        "task": task,
    }


# =============================================================================
# API ENDPOINTS
# =============================================================================

def api_health():
    """Health check: enabled model keys, compute device, version and sample ids."""
    enabled = [k for k, v in MODELS_CONFIG.items() if v.enabled]
    return {
        "status": "healthy",
        "models": enabled,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "version": "2.0.0",
        "samples_available": list(SAMPLE_IMAGES.keys()),
    }


def _resolve_file_path(file_obj: Any) -> Optional[str]:
    """Return a filesystem path for the file values callers and Gradio pass in.

    Accepts None, ``str`` / ``os.PathLike`` (checked first: ``pathlib.Path``
    also has a ``.name`` attribute, but that is only the bare file name), a
    dict with a ``"path"`` key, or a file wrapper exposing ``.name``; anything
    else is converted with ``str()``.
    """
    if file_obj is None:
        return None
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    if isinstance(file_obj, dict):
        return file_obj.get("path")
    if hasattr(file_obj, "name"):
        return file_obj.name
    return str(file_obj)


def _load_grayscale(image_file: Any) -> np.ndarray:
    """Open *image_file* (any form accepted by ``_resolve_file_path``) as an 8-bit grayscale array.

    Raises:
        ValueError: if no path can be resolved.
        OSError: if the file cannot be opened or decoded.
    """
    file_path = _resolve_file_path(image_file)
    if file_path is None:
        raise ValueError("No image provided")
    with Image.open(file_path) as img:
        return np.array(img.convert("L"))


def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT",
                        window_type: str = "Brain"):
    """Gradio-compatible endpoint used by the HydroMorph mobile app.

    NOTE: placeholder segmentation — the mask is a centred ellipse with
    semi-axes H/4 x W/4, independent of *prompt*.

    Args:
        image_file: Uploaded slice (path, ``os.PathLike``, dict with ``"path"``,
            or a file wrapper exposing ``.name``).
        prompt: Target structure; echoed in the status message.
        modality: Imaging modality; echoed in the status message.
        window_type: Display window; currently unused.

    Returns:
        ``(result_png_path, status_message)``; the path is None on error and
        the message then starts with ``"Error:"``.
    """
    try:
        logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
        if image_file is None:
            return None, "Error: No image provided"

        image = _load_grayscale(image_file)
        H, W = image.shape
        yy, xx = np.ogrid[:H, :W]
        ellipse = ((yy - H // 2) / (H / 4)) ** 2 + ((xx - W // 2) / (W / 4)) ** 2 <= 1
        mask = ellipse.astype(np.uint8)

        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
        # mkstemp gives every request its own file; a per-second timestamp
        # name alone lets concurrent requests overwrite each other's result.
        fd, temp_path = tempfile.mkstemp(
            prefix=f"result_{int(time.time())}_", suffix=".png", dir=TEMP_DIR
        )
        with os.fdopen(fd, "wb") as fh:
            overlay.save(fh, format="PNG")
        return temp_path, f"Segmented {prompt} using {modality}"
    except Exception as e:
        logger.exception("process_with_status failed")
        return None, f"Error: {str(e)}"


def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
    """Direct JSON API for 2D segmentation.

    Only ``model="mcp_medsam"`` is supported. Returns a dict with ``success``,
    ``mask_b64``, ``overlay_b64`` (PNG), ``shape``, ``method`` and ``modality``,
    or ``{"error": message}`` on any failure.
    """
    try:
        box = json.loads(box_json)
        image = _load_grayscale(image_file)

        if model != "mcp_medsam":
            return {"error": f"Model {model} not supported for 2D"}
        result = run_mcp_medsam_2d(image, box, modality)

        overlay = overlay_mask_on_image(image, result["mask"])
        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "overlay_b64": image_to_base64(overlay),
            "shape": result["shape"],
            "method": result["method"],
            "modality": result.get("modality", modality),
        }
    except Exception as e:
        logger.exception("2D segmentation failed")
        return {"error": str(e)}


def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
    """Direct JSON API for 3D segmentation.

    Only ``model="medsam2"`` is supported. Returns a dict with ``success``,
    ``mask_b64``, ``shape`` and ``method``, or ``{"error": message}``.
    """
    try:
        file_path = _resolve_file_path(volume_file)
        if file_path is None:
            return {"error": "No volume provided"}
        if model != "medsam2":
            return {"error": f"Model {model} not supported for 3D"}

        with open(file_path, "rb") as f:
            volume_bytes = f.read()
        result = run_medsam2_3d(volume_bytes, box_json)
        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "shape": result["shape"],
            "method": result["method"],
        }
    except Exception as e:
        logger.exception("3D segmentation failed")
        return {"error": str(e)}


# Overlay colour per model in the comparison view. Models not listed here
# (e.g. "tractseg") are skipped by api_compare_models.
_COMPARE_COLORS = {
    "mcp_medsam": (0, 255, 0),
    "medsam2": (255, 0, 0),
    "sam_med3d": (0, 0, 255),
    "medsam_3d": (255, 255, 0),
    "nnunet": (255, 0, 255),
}


def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
    """Run several models on the same 2D image and collect per-model results.

    NOTE(review): every supported model currently runs the MCP-MedSAM 2D
    placeholder and differs only in overlay colour.

    Args:
        image_file: Image in any form accepted by ``_resolve_file_path``.
        box_json: JSON prompt box ``{"x1", "y1", "x2", "y2"}``.
        models_json: JSON list of model keys; keys without a colour in
            ``_COMPARE_COLORS`` are skipped.
        modality: Imaging modality passed to the model.

    Returns:
        ``{"success": True, "results": {model: {...}}}`` where a failing model
        gets ``{"success": False, "error": ...}``, or ``{"error": message}``
        if the inputs themselves are invalid.
    """
    try:
        models = json.loads(models_json)
        box = json.loads(box_json)
        image = _load_grayscale(image_file)

        results = {}
        for model in models:
            color = _COMPARE_COLORS.get(model)
            if color is None:
                continue
            start = time.perf_counter()
            try:
                result = run_mcp_medsam_2d(image, box, modality)
                overlay = overlay_mask_on_image(image, result["mask"], color=color)
                results[model] = {
                    "success": True,
                    "mask_b64": result["mask_b64"],
                    "overlay_b64": image_to_base64(overlay),
                    "inference_time": round(time.perf_counter() - start, 2),
                    "shape": result["shape"],
                }
            except Exception as e:
                logger.exception(f"Model comparison: {model} failed")
                results[model] = {"success": False, "error": str(e)}

        return {"success": True, "results": results}
    except Exception as e:
        logger.exception("Model comparison failed")
        return {"error": str(e)}
# =============================================================================
# GRADIO INTERFACE
# =============================================================================

def create_interface():
    """Build the Gradio Blocks UI that fronts the models and API endpoints.

    Tabs:
        * Try with Sample CT - preview a bundled sample and send it to another tab.
        * Single Model       - run one enabled 2D-capable model on a sample or upload.
        * Model Comparison   - run several models on the same slice side by side.
        * 3D Segmentation    - run a 3D model on an uploaded volume file.
        * Mobile App         - the ``process_with_status`` endpoint used by HydroMorph.
        * Status             - health check and model status table.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """
    default_sample = next(iter(SAMPLE_IMAGES))
    sample_choices = [(v["name"], k) for k, v in SAMPLE_IMAGES.items()]
    source_choices = [("Sample CT", "sample"), ("Upload Image", "upload")]
    # Order must match the comparison checkboxes passed to run_comparison.
    comparison_model_keys = ["medsam2", "mcp_medsam", "sam_med3d", "medsam_3d", "tractseg", "nnunet"]

    def pick_value(choices, preferred):
        """Return *preferred* if it is one of the (label, key) *choices*, else the first key (or None)."""
        keys = [key for _, key in choices]
        if preferred in keys:
            return preferred
        return keys[0] if keys else None

    def single_model_choices(prompt_only=False):
        """Enabled 2D-capable models as (label, key) pairs, optionally prompt-based only."""
        return [
            (cfg.name, key)
            for key, cfg in MODELS_CONFIG.items()
            if cfg.enabled and cfg.supports_2d and (cfg.needs_prompt or not prompt_only)
        ]

    def toggle_source_visibility(source):
        """Show the sample dropdown for "sample" and the upload widget for "upload"."""
        return gr.update(visible=source == "sample"), gr.update(visible=source == "upload")

    def box_for_sample(sample_id):
        """Update for a box textbox holding the sample's default box (no-op for unknown ids)."""
        sample = SAMPLE_IMAGES.get(sample_id)
        if sample is None:
            return gr.update()
        return gr.update(value=json.dumps(sample["default_box"]))

    def resolve_input(source, sample, upload):
        """Return the image path for the chosen input source as a string, or None."""
        img_path = get_sample_image_path(sample) if source == "sample" else upload
        # Hand a plain string downstream: some API helpers treat any object with
        # a `.name` attribute as a file wrapper, which would reduce a
        # pathlib.Path to its bare file name.
        return None if img_path is None else str(img_path)

    with gr.Blocks(
        title="NeuroSeg Server - HydroMorph Backend",
        theme=gr.themes.Soft(),
        css="""
        .sample-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
        .model-checkbox { margin: 5px 0; }
        """,
    ) as demo:
        gr.Markdown(
            """
# 🧠 NeuroSeg Server

Backend API for HydroMorph React Native app (iOS, Android, Web).

**MCP Server**: `https://mmrech-medsam2-server.hf.space/gradio_api/mcp/sse`

**Models**: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
"""
        )

        with gr.Tabs() as tabs:
            # --- Try with Sample CT ---
            with gr.Tab("📋 Try with Sample CT", id="samples"):
                gr.Markdown("Select a sample CT scan to test the models:")
                sample_radio = gr.Radio(
                    choices=[
                        (f"{v['name']}: {v['description'][:50]}...", k)
                        for k, v in SAMPLE_IMAGES.items()
                    ],
                    value=default_sample,
                    label="Select Sample",
                )
                with gr.Row():
                    sample_preview = gr.Image(label="Selected Sample", type="pil")
                    sample_info = gr.JSON(label="Sample Info")

                def load_sample_preview(sample_id):
                    """Return (PIL preview, metadata) for a sample, or (None, {}) on failure."""
                    result = load_sample_image(sample_id)
                    if result is None:
                        return None, {}
                    img_array, meta = result
                    return Image.fromarray(img_array), meta

                sample_radio.change(
                    fn=load_sample_preview,
                    inputs=[sample_radio],
                    outputs=[sample_preview, sample_info],
                )
                # Populate the preview once when the page first loads.
                demo.load(
                    fn=load_sample_preview,
                    inputs=[sample_radio],
                    outputs=[sample_preview, sample_info],
                )

                gr.Markdown("### Use this sample in:")
                with gr.Row():
                    use_in_single = gr.Button("🎯 Single Model", variant="secondary")
                    use_in_compare = gr.Button("🔬 Model Comparison", variant="secondary")

            # --- Single Model ---
            with gr.Tab("🎯 Single Model", id="single"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_source = gr.Radio(
                            choices=source_choices, value="sample", label="Input Source"
                        )
                        single_sample = gr.Dropdown(
                            choices=sample_choices, value=default_sample, label="Sample", visible=True
                        )
                        single_upload = gr.Image(label="Upload", type="filepath", visible=False)
                        input_source.change(
                            fn=toggle_source_visibility,
                            inputs=[input_source],
                            outputs=[single_sample, single_upload],
                        )

                        # Only 2D-capable models can run on a single slice;
                        # prefer MCP-MedSAM, the model api_segment_2d supports.
                        initial_single_choices = single_model_choices()
                        single_model = gr.Dropdown(
                            choices=initial_single_choices,
                            value=pick_value(initial_single_choices, "mcp_medsam"),
                            label="Model",
                        )
                        single_box = gr.Textbox(
                            label="Bounding Box (JSON)",
                            value=json.dumps(SAMPLE_IMAGES[default_sample]["default_box"]),
                            visible=True,
                        )
                        single_modality = gr.Dropdown(
                            label="Modality",
                            choices=list(MODALITY_MAP.keys()),
                            value="CT",
                            visible=True,
                        )
                        single_prompt_only = gr.Checkbox(
                            label="Show only prompt-based models", value=False
                        )
                        single_run = gr.Button("🚀 Run Model", variant="primary")

                    with gr.Column(scale=2):
                        single_output = gr.JSON(label="Result")
                        single_overlay = gr.Image(label="Segmentation Overlay")

                def filter_single_models(prompt_only, current):
                    """Re-filter the model dropdown, keeping the current model if still listed."""
                    choices = single_model_choices(prompt_only)
                    return gr.update(choices=choices, value=pick_value(choices, current))

                single_prompt_only.change(
                    fn=filter_single_models,
                    inputs=[single_prompt_only, single_model],
                    outputs=[single_model],
                )
                single_sample.change(fn=box_for_sample, inputs=[single_sample], outputs=[single_box])

                def run_single(source, sample, upload, model, box, modality):
                    """Run one 2D model; return (JSON result, overlay image or None)."""
                    img_path = resolve_input(source, sample, upload)
                    if img_path is None:
                        return {"error": "No image provided"}, None
                    result = api_segment_2d(img_path, box, model, modality)
                    if "error" in result:
                        return result, None
                    # api_segment_2d already rendered the overlay; decode it.
                    return result, base64_to_image(result["overlay_b64"])

                single_run.click(
                    fn=run_single,
                    inputs=[input_source, single_sample, single_upload,
                            single_model, single_box, single_modality],
                    outputs=[single_output, single_overlay],
                )

            # --- Model Comparison ---
            with gr.Tab("🔬 Model Comparison", id="compare"):
                with gr.Row():
                    with gr.Column(scale=1):
                        comp_input_source = gr.Radio(
                            choices=source_choices, value="sample", label="Input Source"
                        )
                        comp_sample = gr.Dropdown(
                            choices=sample_choices, value=default_sample, label="Sample", visible=True
                        )
                        comp_upload = gr.Image(label="Upload", type="filepath", visible=False)
                        comp_input_source.change(
                            fn=toggle_source_visibility,
                            inputs=[comp_input_source],
                            outputs=[comp_sample, comp_upload],
                        )

                        comp_box = gr.Textbox(
                            label="Bounding Box (JSON)",
                            value=json.dumps(SAMPLE_IMAGES[default_sample]["default_box"]),
                        )
                        comp_modality = gr.Dropdown(
                            label="Modality", choices=list(MODALITY_MAP.keys()), value="CT"
                        )

                        gr.Markdown("### Select Models to Compare")
                        comp_prompt_only = gr.Checkbox(
                            label="Prompt-based models only",
                            value=False,
                            info="Filter to models that accept prompts",
                        )

                        gr.Markdown("**Foundation Models**")
                        comp_medsam2 = gr.Checkbox(label="MedSAM2 (3D Bi-directional)", value=True)
                        comp_mcp = gr.Checkbox(label="MCP-MedSAM (Fast 2D)", value=True)
                        comp_sam3d = gr.Checkbox(label="SAM-Med3D (245+ classes)", value=False)
                        comp_medsam3d = gr.Checkbox(label="MedSAM-3D (Memory Bank)", value=False)

                        gr.Markdown("**Specialized Models**")
                        comp_tractseg = gr.Checkbox(label="TractSeg (72 bundles)", value=False)
                        comp_nnunet = gr.Checkbox(label="nnU-Net (Auto-configuring)", value=False)

                        comp_run = gr.Button("🚀 Run Comparison", variant="primary")

                    with gr.Column(scale=2):
                        comp_output = gr.JSON(label="Comparison Results")
                        comp_gallery = gr.Gallery(label="Model Overlays")

                comp_sample.change(fn=box_for_sample, inputs=[comp_sample], outputs=[comp_box])

                def run_comparison(source, sample, upload, box, modality, prompt_only, *model_flags):
                    """Run every ticked model (optionally prompt-based only); return (JSON, gallery)."""
                    models = [
                        name
                        for name, ticked in zip(comparison_model_keys, model_flags)
                        if ticked and (MODELS_CONFIG[name].needs_prompt or not prompt_only)
                    ]
                    if not models:
                        return {"error": "No models selected"}, []

                    img_path = resolve_input(source, sample, upload)
                    if img_path is None:
                        return {"error": "No image provided"}, []

                    result = api_compare_models(img_path, box, json.dumps(models), modality)
                    if "error" in result:
                        return result, []

                    gallery = [
                        (base64_to_image(data["overlay_b64"]),
                         f"{model} ({data.get('inference_time', 0)}s)")
                        for model, data in result.get("results", {}).items()
                        if data.get("success") and "overlay_b64" in data
                    ]
                    return result, gallery

                comp_run.click(
                    fn=run_comparison,
                    inputs=[
                        comp_input_source, comp_sample, comp_upload, comp_box,
                        comp_modality, comp_prompt_only,
                        comp_medsam2, comp_mcp, comp_sam3d, comp_medsam3d,
                        comp_tractseg, comp_nnunet,
                    ],
                    outputs=[comp_output, comp_gallery],
                )

            # --- 3D Segmentation ---
            with gr.Tab("🏥 3D Segmentation", id="3d"):
                seg3d_choices = [
                    (v.name, k) for k, v in MODELS_CONFIG.items() if v.supports_3d and v.enabled
                ]
                with gr.Row():
                    with gr.Column():
                        seg3d_volume = gr.File(
                            label="Volume (.npy or .nii.gz)", file_types=[".npy", ".nii.gz"]
                        )
                        seg3d_box = gr.Textbox(
                            label="Box + Slice (JSON)",
                            value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200, "slice_idx": 32}',
                        )
                        # Default to MedSAM2 only when it is actually enabled.
                        seg3d_model = gr.Dropdown(
                            label="Model",
                            choices=seg3d_choices,
                            value=pick_value(seg3d_choices, "medsam2"),
                        )
                        seg3d_run = gr.Button("Segment", variant="primary")
                    with gr.Column():
                        seg3d_output = gr.JSON(label="Result")

                seg3d_run.click(
                    fn=api_segment_3d,
                    inputs=[seg3d_volume, seg3d_box, seg3d_model],
                    outputs=seg3d_output,
                )

            # --- Mobile App Endpoint ---
            with gr.Tab("📱 Mobile App", id="mobile"):
                gr.Markdown(
                    """
This endpoint is used by the HydroMorph mobile app.

**Endpoint:** `POST /gradio_api/call/process_with_status`
"""
                )
                with gr.Row():
                    with gr.Column():
                        mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
                        mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
                        mobile_modality = gr.Dropdown(
                            label="Modality", choices=["CT", "MRI", "PET"], value="CT"
                        )
                        mobile_window = gr.Dropdown(
                            label="Window",
                            choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
                            value="Brain (Grey Matter)",
                        )
                        mobile_btn = gr.Button("Run Segmentation", variant="primary")
                    with gr.Column():
                        mobile_result_img = gr.Image(label="Result")
                        mobile_status = gr.Textbox(label="Status")

                mobile_btn.click(
                    fn=process_with_status,
                    inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
                    outputs=[mobile_result_img, mobile_status],
                    api_name="process_with_status",
                )

            # --- Status ---
            with gr.Tab("⚙️ Status", id="status"):
                health_btn = gr.Button("Check Health")
                health_output = gr.JSON(label="System Status")

                gr.Markdown("### Model Status")
                model_status_data = [
                    [
                        cfg.name,
                        "✅ Enabled" if cfg.enabled else "❌ Disabled",
                        cfg.category,
                        "Yes" if cfg.needs_prompt else "No",
                    ]
                    for cfg in MODELS_CONFIG.values()
                ]
                gr.Dataframe(
                    headers=["Model", "Status", "Category", "Needs Prompt"],
                    value=model_status_data,
                )

                health_btn.click(fn=api_health, outputs=health_output, api_name="health")

        # "Use this sample in" buttons: switch tab, select the sample as input
        # and prefill its default bounding box.
        def send_sample_to(tab_id):
            """Build a click handler that routes the chosen sample to tab *tab_id*."""
            def handler(sample_id):
                return (
                    gr.update(selected=tab_id),
                    gr.update(value="sample"),
                    gr.update(value=sample_id),
                    box_for_sample(sample_id),
                )
            return handler

        use_in_single.click(
            fn=send_sample_to("single"),
            inputs=[sample_radio],
            outputs=[tabs, input_source, single_sample, single_box],
        )
        use_in_compare.click(
            fn=send_sample_to("compare"),
            inputs=[sample_radio],
            outputs=[tabs, comp_input_source, comp_sample, comp_box],
        )

    return demo


# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    enabled = [k for k, v in MODELS_CONFIG.items() if v.enabled]
    logger.info(f"Starting NeuroSeg Server with {len(enabled)} models: {enabled}")
    logger.info(f"Samples configured: {list(SAMPLE_IMAGES.keys())}")

    demo = create_interface()

    # Launch with MCP server support.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
        quiet=False,
        mcp_server=True,  # Enable MCP server
    )