Spaces:
Runtime error
Runtime error
Refactor for HydroMorph mobile app compatibility
Browse files- Add process_with_status endpoint for GradioClient.js compatibility
- Add api_segment_2d/3d for direct JSON API calls
- Add api_compare_models for multi-model comparison
- Add overlay_mask_on_image for result visualization
- Add compress_mask/decompress_mask utilities
- Add image_to_base64/base64_to_image utilities
- Implement run_medsam2_3d, run_mcp_medsam_2d, run_sam_med3d
- Update Gradio UI with mobile app, 2D, 3D, comparison tabs
- Ensure API returns [image, status] format expected by mobile app
- Add proper error handling and logging for mobile debugging
- .DS_Store +0 -0
- __pycache__/app.cpython-313.pyc +0 -0
- app.py +487 -794
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
__pycache__/app.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -1,19 +1,24 @@
|
|
| 1 |
"""
|
| 2 |
-
NeuroSeg Server —
|
| 3 |
-
=========================================
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
MODELS SUPPORTED:
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
|
| 13 |
-
Specialized Models:
|
| 14 |
-
- TractSeg: White matter bundle segmentation (72 bundles)
|
| 15 |
-
- nnU-Net: Self-configuring U-Net for any biomedical dataset
|
| 16 |
-
- NeuroSAM3: Advanced neuroimage segmentation (placeholder)
|
| 17 |
|
| 18 |
Author: Matheus Machado Rech
|
| 19 |
"""
|
|
@@ -25,11 +30,11 @@ import logging
|
|
| 25 |
import os
|
| 26 |
import tempfile
|
| 27 |
import base64
|
| 28 |
-
import
|
| 29 |
-
from typing import Optional, Tuple, List, Dict, Any
|
| 30 |
from dataclasses import dataclass, field
|
| 31 |
-
from enum import Enum
|
| 32 |
from pathlib import Path
|
|
|
|
| 33 |
|
| 34 |
import gradio as gr
|
| 35 |
import spaces
|
|
@@ -39,871 +44,559 @@ import torch.nn as nn
|
|
| 39 |
import torch.nn.functional as F
|
| 40 |
from PIL import Image, ImageDraw
|
| 41 |
from huggingface_hub import hf_hub_download
|
| 42 |
-
|
| 43 |
import nibabel as nib
|
| 44 |
import scipy
|
| 45 |
|
| 46 |
-
#
|
| 47 |
-
|
| 48 |
-
# ---------------------------------------------------------------------------
|
| 49 |
-
|
| 50 |
-
logging.basicConfig(
|
| 51 |
-
level=logging.INFO,
|
| 52 |
-
format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
|
| 53 |
-
)
|
| 54 |
logger = logging.getLogger("neuroseg_server")
|
| 55 |
|
| 56 |
-
#
|
| 57 |
-
# Configuration & Feature Flags
|
| 58 |
-
# ---------------------------------------------------------------------------
|
| 59 |
-
|
| 60 |
SCRIPT_DIR = Path(__file__).parent.resolve()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
needs_prompt: bool = True
|
| 68 |
-
supports_2d: bool = False
|
| 69 |
-
supports_3d: bool = False
|
| 70 |
-
supports_dwi: bool = False
|
| 71 |
-
supports_multiclass: bool = False
|
| 72 |
-
supports_sliding_window: bool = False
|
| 73 |
-
realtime: bool = False # Fast enough for interactive use
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
@dataclass
|
| 77 |
-
class ModelConfig:
|
| 78 |
-
"""Configuration for each supported model."""
|
| 79 |
-
name: str
|
| 80 |
-
enabled: bool
|
| 81 |
-
description: str
|
| 82 |
-
short_desc: str
|
| 83 |
-
capabilities: ModelCapability
|
| 84 |
-
preferred_formats: List[str]
|
| 85 |
-
default_prompt: Dict = field(default_factory=dict)
|
| 86 |
-
category: str = "foundation" # foundation, specialized, experimental
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
# Feature flags - control which models are available
|
| 90 |
MODELS = {
|
| 91 |
-
|
| 92 |
-
"
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
capabilities=ModelCapability(
|
| 98 |
-
needs_prompt=True,
|
| 99 |
-
supports_2d=False,
|
| 100 |
-
supports_3d=True,
|
| 101 |
-
supports_dwi=False,
|
| 102 |
-
realtime=False
|
| 103 |
-
),
|
| 104 |
-
preferred_formats=[".npy", ".nii.gz"],
|
| 105 |
-
default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220, "slice_idx": 32},
|
| 106 |
-
category="foundation"
|
| 107 |
-
),
|
| 108 |
-
"mcp_medsam": ModelConfig(
|
| 109 |
-
name="MCP-MedSAM",
|
| 110 |
-
enabled=os.getenv("ENABLE_MCP_MEDSAM", "true").lower() == "true",
|
| 111 |
-
description="Lightweight 2D segmentation with explicit modality and content prompts (~5x faster)",
|
| 112 |
-
short_desc="Fast 2D + Modality",
|
| 113 |
-
capabilities=ModelCapability(
|
| 114 |
-
needs_prompt=True,
|
| 115 |
-
supports_2d=True,
|
| 116 |
-
supports_3d=False,
|
| 117 |
-
supports_dwi=False,
|
| 118 |
-
realtime=True
|
| 119 |
-
),
|
| 120 |
-
preferred_formats=[".png", ".jpg", ".jpeg", ".npy", ".nii.gz"],
|
| 121 |
-
default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220},
|
| 122 |
-
category="foundation"
|
| 123 |
-
),
|
| 124 |
-
"sam_med3d": ModelConfig(
|
| 125 |
-
name="SAM-Med3D",
|
| 126 |
-
enabled=os.getenv("ENABLE_SAM_MED3D", "false").lower() == "true",
|
| 127 |
-
description="Native 3D SAM with 245+ classes and sliding window for large volumes",
|
| 128 |
-
short_desc="3D Multi-class (245+)",
|
| 129 |
-
capabilities=ModelCapability(
|
| 130 |
-
needs_prompt=True,
|
| 131 |
-
supports_2d=False,
|
| 132 |
-
supports_3d=True,
|
| 133 |
-
supports_multiclass=True,
|
| 134 |
-
supports_sliding_window=True,
|
| 135 |
-
realtime=False
|
| 136 |
-
),
|
| 137 |
-
preferred_formats=[".nii.gz"],
|
| 138 |
-
default_prompt={"points": [[64, 64, 64]], "labels": [1]},
|
| 139 |
-
category="foundation"
|
| 140 |
-
),
|
| 141 |
-
"medsam_3d": ModelConfig(
|
| 142 |
-
name="MedSAM-3D",
|
| 143 |
-
enabled=os.getenv("ENABLE_MEDSAM_3D", "false").lower() == "true",
|
| 144 |
-
description="3D MedSAM with self-sorting memory bank for consistent volumetric segmentation",
|
| 145 |
-
short_desc="3D Memory Bank",
|
| 146 |
-
capabilities=ModelCapability(
|
| 147 |
-
needs_prompt=True,
|
| 148 |
-
supports_2d=False,
|
| 149 |
-
supports_3d=True,
|
| 150 |
-
realtime=False
|
| 151 |
-
),
|
| 152 |
-
preferred_formats=[".nii.gz", ".npy"],
|
| 153 |
-
default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220, "slice_idx": 32},
|
| 154 |
-
category="foundation"
|
| 155 |
-
),
|
| 156 |
-
# Specialized Models
|
| 157 |
-
"tractseg": ModelConfig(
|
| 158 |
-
name="TractSeg",
|
| 159 |
-
enabled=os.getenv("ENABLE_TRACTSEG", "true").lower() == "true",
|
| 160 |
-
description="White matter bundle segmentation from diffusion MRI (72 bundles)",
|
| 161 |
-
short_desc="72 WM Bundles",
|
| 162 |
-
capabilities=ModelCapability(
|
| 163 |
-
needs_prompt=False, # Fully automatic
|
| 164 |
-
supports_2d=False,
|
| 165 |
-
supports_3d=True,
|
| 166 |
-
supports_dwi=True,
|
| 167 |
-
supports_multiclass=True,
|
| 168 |
-
realtime=False
|
| 169 |
-
),
|
| 170 |
-
preferred_formats=[".nii.gz"],
|
| 171 |
-
default_prompt={},
|
| 172 |
-
category="specialized"
|
| 173 |
-
),
|
| 174 |
-
"nnunet": ModelConfig(
|
| 175 |
-
name="nnU-Net",
|
| 176 |
-
enabled=os.getenv("ENABLE_NNUNET", "true").lower() == "true",
|
| 177 |
-
description="Self-configuring U-Net that auto-tunes to any biomedical dataset (SOTA baseline)",
|
| 178 |
-
short_desc="Auto-Configuring",
|
| 179 |
-
capabilities=ModelCapability(
|
| 180 |
-
needs_prompt=False, # Task-based
|
| 181 |
-
supports_2d=True,
|
| 182 |
-
supports_3d=True,
|
| 183 |
-
supports_multiclass=True,
|
| 184 |
-
realtime=False
|
| 185 |
-
),
|
| 186 |
-
preferred_formats=[".nii.gz", ".npy"],
|
| 187 |
-
default_prompt={"task": "Task001_BrainTumour"},
|
| 188 |
-
category="specialized"
|
| 189 |
-
),
|
| 190 |
-
"neurosam3": ModelConfig(
|
| 191 |
-
name="NeuroSAM3",
|
| 192 |
-
enabled=os.getenv("ENABLE_NEUROSAM3", "false").lower() == "true",
|
| 193 |
-
description="Advanced neuroimage segmentation (pending configuration)",
|
| 194 |
-
short_desc="Advanced (Pending)",
|
| 195 |
-
capabilities=ModelCapability(
|
| 196 |
-
needs_prompt=True,
|
| 197 |
-
supports_2d=True,
|
| 198 |
-
supports_3d=True,
|
| 199 |
-
realtime=False
|
| 200 |
-
),
|
| 201 |
-
preferred_formats=[".nii.gz", ".npy", ".png", ".jpg"],
|
| 202 |
-
default_prompt={},
|
| 203 |
-
category="experimental"
|
| 204 |
-
),
|
| 205 |
}
|
| 206 |
|
| 207 |
-
|
| 208 |
-
SAMPLE_IMAGES = {
|
| 209 |
-
"nph_1": {
|
| 210 |
-
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36.png",
|
| 211 |
-
"name": "NPH Case 1 - Coronal",
|
| 212 |
-
"description": "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
|
| 213 |
-
"modality": "CT",
|
| 214 |
-
"default_box": {"x1": 450, "y1": 350, "x2": 750, "y2": 700},
|
| 215 |
-
"filename": "normal-pressure-hydrocephalus-36.png"
|
| 216 |
-
},
|
| 217 |
-
"nph_2": {
|
| 218 |
-
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-2.png",
|
| 219 |
-
"name": "NPH Case 2 - Coronal",
|
| 220 |
-
"description": "NPH showing ventricular enlargement and transependymal changes",
|
| 221 |
-
"modality": "CT",
|
| 222 |
-
"default_box": {"x1": 400, "y1": 300, "x2": 700, "y2": 650},
|
| 223 |
-
"filename": "normal-pressure-hydrocephalus-36-2.png"
|
| 224 |
-
},
|
| 225 |
-
"nph_3": {
|
| 226 |
-
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-3.png",
|
| 227 |
-
"name": "NPH Case 3 - Axial",
|
| 228 |
-
"description": "Axial view showing enlarged lateral ventricles",
|
| 229 |
-
"modality": "CT",
|
| 230 |
-
"default_box": {"x1": 420, "y1": 380, "x2": 680, "y2": 620},
|
| 231 |
-
"filename": "normal-pressure-hydrocephalus-36-3.png"
|
| 232 |
-
}
|
| 233 |
-
}
|
| 234 |
|
| 235 |
-
#
|
| 236 |
-
|
| 237 |
-
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
|
| 246 |
-
# TractSeg bundles
|
| 247 |
-
TRACTSEG_BUNDLES = [
|
| 248 |
-
"AF_left", "AF_right", "ATR_left", "ATR_right", "CA",
|
| 249 |
-
"CC_1", "CC_2", "CC_3", "CC_4", "CC_5", "CC_6", "CC_7",
|
| 250 |
-
"CG_left", "CG_right", "CST_left", "CST_right", "MLF_left", "MLF_right",
|
| 251 |
-
"FPT_left", "FPT_right", "FX_left", "FX_right", "ICP_left", "ICP_right",
|
| 252 |
-
"IFO_left", "IFO_right", "ILF_left", "ILF_right", "MCP",
|
| 253 |
-
"OR_left", "OR_right", "POPT_left", "POPT_right", "SCP_left", "SCP_right",
|
| 254 |
-
"SLF_I_left", "SLF_I_right", "SLF_II_left", "SLF_II_right", "SLF_III_left", "SLF_III_right",
|
| 255 |
-
"STR_left", "STR_right", "UF_left", "UF_right", "CC",
|
| 256 |
-
"T_PREF_left", "T_PREF_right", "T_PREM_left", "T_PREM_right",
|
| 257 |
-
"T_PREC_left", "T_PREC_right", "T_POSTC_left", "T_POSTC_right",
|
| 258 |
-
"T_PAR_left", "T_PAR_right", "T_OCC_left", "T_OCC_right",
|
| 259 |
-
"ST_FO_left", "ST_FO_right", "ST_PREF_left", "ST_PREF_right",
|
| 260 |
-
"ST_PREM_left", "ST_PREM_right", "ST_PREC_left", "ST_PREC_right",
|
| 261 |
-
"ST_POSTC_left", "ST_POSTC_right", "ST_PAR_left", "ST_PAR_right",
|
| 262 |
-
"ST_OCC_left", "ST_OCC_right",
|
| 263 |
-
]
|
| 264 |
-
|
| 265 |
-
# ---------------------------------------------------------------------------
|
| 266 |
-
# Utility Functions
|
| 267 |
-
# ---------------------------------------------------------------------------
|
| 268 |
-
|
| 269 |
-
def get_enabled_models(category: Optional[str] = None, needs_prompt: Optional[bool] = None) -> Dict[str, ModelConfig]:
|
| 270 |
-
"""Get enabled models, optionally filtered by category or prompt requirement."""
|
| 271 |
-
models = {k: v for k, v in MODELS.items() if v.enabled}
|
| 272 |
-
if category:
|
| 273 |
-
models = {k: v for k, v in models.items() if v.category == category}
|
| 274 |
-
if needs_prompt is not None:
|
| 275 |
-
models = {k: v for k, v in models.items() if v.capabilities.needs_prompt == needs_prompt}
|
| 276 |
-
return models
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
|
| 280 |
-
"""Load a sample image by ID, downloading if necessary."""
|
| 281 |
-
if sample_id not in SAMPLE_IMAGES:
|
| 282 |
-
return None
|
| 283 |
-
|
| 284 |
-
sample = SAMPLE_IMAGES[sample_id]
|
| 285 |
-
img_path = SAMPLES_DIR / sample["filename"]
|
| 286 |
-
|
| 287 |
-
# Download if not cached
|
| 288 |
-
if not img_path.exists():
|
| 289 |
-
try:
|
| 290 |
-
import urllib.request
|
| 291 |
-
logger.info(f"Downloading sample {sample_id} from {sample['url']}")
|
| 292 |
-
SAMPLES_DIR.mkdir(exist_ok=True)
|
| 293 |
-
urllib.request.urlretrieve(sample["url"], img_path)
|
| 294 |
-
logger.info(f"Sample downloaded to {img_path}")
|
| 295 |
-
except Exception as e:
|
| 296 |
-
logger.error(f"Failed to download sample {sample_id}: {e}")
|
| 297 |
-
# Create a placeholder image
|
| 298 |
-
img_array = np.zeros((512, 512), dtype=np.uint8)
|
| 299 |
-
meta = {
|
| 300 |
-
"name": sample["name"],
|
| 301 |
-
"description": f"{sample['description']} (Download failed)",
|
| 302 |
-
"modality": sample["modality"],
|
| 303 |
-
"default_box": sample["default_box"],
|
| 304 |
-
"shape": img_array.shape,
|
| 305 |
-
"error": str(e)
|
| 306 |
-
}
|
| 307 |
-
return img_array, meta
|
| 308 |
-
|
| 309 |
-
img = Image.open(img_path)
|
| 310 |
-
img_array = np.array(img)
|
| 311 |
-
|
| 312 |
-
# Convert to grayscale if needed
|
| 313 |
-
if len(img_array.shape) == 3:
|
| 314 |
-
if img_array.shape[2] == 4: # RGBA
|
| 315 |
-
img_array = np.array(Image.fromarray(img_array).convert('L'))
|
| 316 |
-
elif img_array.shape[2] == 3: # RGB
|
| 317 |
-
img_array = np.array(Image.fromarray(img_array).convert('L'))
|
| 318 |
-
|
| 319 |
-
meta = {
|
| 320 |
-
"name": sample["name"],
|
| 321 |
-
"description": sample["description"],
|
| 322 |
-
"modality": sample["modality"],
|
| 323 |
-
"default_box": sample["default_box"],
|
| 324 |
-
"shape": img_array.shape
|
| 325 |
-
}
|
| 326 |
-
|
| 327 |
-
return img_array, meta
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
-
def create_comparison_visualization(results: Dict[str, Any], original_image: np.ndarray) -> Image.Image:
|
| 331 |
-
"""Create a side-by-side visualization of model comparisons."""
|
| 332 |
-
# This is a placeholder - in production would create actual comparison image
|
| 333 |
-
return Image.fromarray(original_image)
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
-
# [Previous utility functions: load_nifti, save_nifti, load_image, window_and_normalise,
|
| 337 |
-
# resize_slice, prepare_tensor_medsam2, preprocess_2d_for_mcp, SimpleMCPMedSAM, etc.]
|
| 338 |
-
# Keeping them for brevity - assuming they're similar to previous implementation
|
| 339 |
|
| 340 |
-
def
|
| 341 |
-
"""
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
affine = nii.affine
|
| 345 |
-
spacing = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
|
| 346 |
-
return data, affine, {"spacing": spacing.tolist(), "shape": list(data.shape)}
|
| 347 |
|
| 348 |
|
| 349 |
-
def
|
| 350 |
-
"""
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
-
#
|
| 359 |
-
#
|
| 360 |
-
#
|
| 361 |
|
| 362 |
@spaces.GPU(duration=120)
|
| 363 |
-
def
|
| 364 |
-
"""
|
| 365 |
-
|
|
|
|
|
|
|
| 366 |
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
prompt: Model-specific prompt (box, points, etc.)
|
| 371 |
-
**kwargs: Additional model-specific arguments
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
start_time = time.time()
|
| 378 |
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
mask = np.random.randint(0, 4, size=image.shape[:3] if image.ndim >= 3 else image.shape[:2])
|
| 396 |
-
elif model_name == "neurosam3":
|
| 397 |
-
mask = np.random.rand(*image.shape[:2]) > 0.5
|
| 398 |
-
else:
|
| 399 |
-
return {"error": f"Unknown model: {model_name}"}
|
| 400 |
|
| 401 |
-
|
|
|
|
|
|
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
|
|
|
| 408 |
|
| 409 |
return {
|
| 410 |
-
"
|
|
|
|
| 411 |
"shape": list(mask.shape),
|
| 412 |
-
"method":
|
| 413 |
-
"
|
| 414 |
-
"prompt_used": prompt,
|
| 415 |
}
|
| 416 |
|
| 417 |
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
|
| 443 |
-
value="nph_1",
|
| 444 |
-
label="Select Sample",
|
| 445 |
-
visible=True
|
| 446 |
-
)
|
| 447 |
-
|
| 448 |
-
# File upload (hidden by default)
|
| 449 |
-
comp_file_upload = gr.File(
|
| 450 |
-
label="Upload Image",
|
| 451 |
-
type="filepath",
|
| 452 |
-
file_types=[".png", ".jpg", ".jpeg", ".nii.gz", ".npy"],
|
| 453 |
-
visible=False
|
| 454 |
-
)
|
| 455 |
-
|
| 456 |
-
# Show/hide based on input type
|
| 457 |
-
def toggle_input(input_type):
|
| 458 |
-
return {
|
| 459 |
-
comp_sample_select: gr.update(visible=input_type == "Sample CT Scan"),
|
| 460 |
-
comp_file_upload: gr.update(visible=input_type == "Upload Image")
|
| 461 |
-
}
|
| 462 |
-
|
| 463 |
-
comp_input_type.change(
|
| 464 |
-
fn=toggle_input,
|
| 465 |
-
inputs=[comp_input_type],
|
| 466 |
-
outputs=[comp_sample_select, comp_file_upload]
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
# Model selection
|
| 470 |
-
gr.Markdown("### Model Selection")
|
| 471 |
-
|
| 472 |
-
comp_prompt_only = gr.Checkbox(
|
| 473 |
-
label="Prompt-based models only",
|
| 474 |
-
value=False,
|
| 475 |
-
info="Show only models that accept prompts (box, points)"
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
# Dynamic model checkboxes
|
| 479 |
-
enabled_models = get_enabled_models()
|
| 480 |
-
model_checkboxes = []
|
| 481 |
-
|
| 482 |
-
with gr.Group():
|
| 483 |
-
gr.Markdown("**Foundation Models**")
|
| 484 |
-
for key, config in enabled_models.items():
|
| 485 |
-
if config.category == "foundation":
|
| 486 |
-
cb = gr.Checkbox(
|
| 487 |
-
label=f"{config.name} - {config.short_desc}",
|
| 488 |
-
value=key in ["medsam2", "mcp_medsam"],
|
| 489 |
-
info=config.description[:80] + "..."
|
| 490 |
-
)
|
| 491 |
-
model_checkboxes.append((key, cb))
|
| 492 |
-
|
| 493 |
-
gr.Markdown("**Specialized Models**")
|
| 494 |
-
for key, config in enabled_models.items():
|
| 495 |
-
if config.category == "specialized":
|
| 496 |
-
cb = gr.Checkbox(
|
| 497 |
-
label=f"{config.name} - {config.short_desc}",
|
| 498 |
-
value=False,
|
| 499 |
-
info=config.description[:80] + "..."
|
| 500 |
-
)
|
| 501 |
-
model_checkboxes.append((key, cb))
|
| 502 |
-
|
| 503 |
-
# Prompt input (for prompt-based models)
|
| 504 |
-
with gr.Group():
|
| 505 |
-
gr.Markdown("### Prompt Configuration")
|
| 506 |
-
comp_box = gr.Textbox(
|
| 507 |
-
label="Bounding Box (JSON)",
|
| 508 |
-
value='{"x1": 450, "y1": 350, "x2": 750, "y2": 700}',
|
| 509 |
-
info="Format: {\"x1\": int, \"y1\": int, \"x2\": int, \"y2\": int}"
|
| 510 |
-
)
|
| 511 |
-
comp_modality = gr.Dropdown(
|
| 512 |
-
label="Modality (for MCP-MedSAM)",
|
| 513 |
-
choices=list(MODALITY_MAP.keys()),
|
| 514 |
-
value="CT"
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
comp_run_btn = gr.Button("🚀 Run Comparison", variant="primary", size="lg")
|
| 518 |
-
|
| 519 |
-
with gr.Column(scale=2):
|
| 520 |
-
# Results display
|
| 521 |
-
comp_status = gr.Textbox(label="Status", value="Ready", lines=2)
|
| 522 |
-
comp_results = gr.JSON(label="Comparison Results")
|
| 523 |
-
|
| 524 |
-
# Visualization
|
| 525 |
-
comp_viz = gr.Image(label="Side-by-Side Comparison", type="pil")
|
| 526 |
|
| 527 |
-
# Store model checkboxes for reference
|
| 528 |
return {
|
| 529 |
-
"
|
| 530 |
-
"
|
| 531 |
-
"
|
| 532 |
-
"
|
| 533 |
-
"
|
| 534 |
-
"
|
| 535 |
-
"modality": comp_modality,
|
| 536 |
-
"run_btn": comp_run_btn,
|
| 537 |
-
"status": comp_status,
|
| 538 |
-
"results": comp_results,
|
| 539 |
-
"viz": comp_viz
|
| 540 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
|
| 543 |
-
def
|
| 544 |
-
"""
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
| 549 |
-
|
| 550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
-
|
|
|
|
| 553 |
|
| 554 |
-
#
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
)
|
| 560 |
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
single_sample = gr.Dropdown(
|
| 572 |
-
choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
|
| 573 |
-
value="nph_1",
|
| 574 |
-
label="Sample",
|
| 575 |
-
visible=True
|
| 576 |
-
)
|
| 577 |
-
|
| 578 |
-
single_upload = gr.File(
|
| 579 |
-
label="Upload",
|
| 580 |
-
type="filepath",
|
| 581 |
-
file_types=[".png", ".jpg", ".jpeg", ".nii.gz", ".npy", ".dcm"],
|
| 582 |
-
visible=False
|
| 583 |
-
)
|
| 584 |
-
|
| 585 |
-
single_input_type.change(
|
| 586 |
-
fn=lambda x: {
|
| 587 |
-
single_sample: gr.update(visible=x == "Sample"),
|
| 588 |
-
single_upload: gr.update(visible=x == "Upload")
|
| 589 |
-
},
|
| 590 |
-
inputs=[single_input_type],
|
| 591 |
-
outputs=[single_sample, single_upload]
|
| 592 |
-
)
|
| 593 |
-
|
| 594 |
-
# Dynamic model-specific inputs
|
| 595 |
-
with gr.Group() as single_prompt_group:
|
| 596 |
-
gr.Markdown("### Model Configuration")
|
| 597 |
-
|
| 598 |
-
# These will be shown/hidden based on model selection
|
| 599 |
-
single_box = gr.Textbox(
|
| 600 |
-
label="Bounding Box",
|
| 601 |
-
value='{"x1": 450, "y1": 350, "x2": 750, "y2": 700}',
|
| 602 |
-
visible=True
|
| 603 |
-
)
|
| 604 |
-
|
| 605 |
-
single_slice_idx = gr.Number(
|
| 606 |
-
label="Slice Index (for 3D)",
|
| 607 |
-
value=32,
|
| 608 |
-
precision=0,
|
| 609 |
-
visible=True
|
| 610 |
-
)
|
| 611 |
-
|
| 612 |
-
single_modality = gr.Dropdown(
|
| 613 |
-
label="Modality",
|
| 614 |
-
choices=list(MODALITY_MAP.keys()),
|
| 615 |
-
value="CT",
|
| 616 |
-
visible=False
|
| 617 |
-
)
|
| 618 |
-
|
| 619 |
-
single_task = gr.Textbox(
|
| 620 |
-
label="nnU-Net Task",
|
| 621 |
-
value="Task001_BrainTumour",
|
| 622 |
-
visible=False
|
| 623 |
-
)
|
| 624 |
-
|
| 625 |
-
single_points = gr.Textbox(
|
| 626 |
-
label="Points (for SAM-Med3D)",
|
| 627 |
-
value='[[64, 64, 64]]',
|
| 628 |
-
visible=False
|
| 629 |
-
)
|
| 630 |
-
|
| 631 |
-
# Update visible inputs based on model
|
| 632 |
-
def update_model_inputs(model_name):
|
| 633 |
-
if not model_name:
|
| 634 |
-
return {}
|
| 635 |
-
config = MODELS.get(model_name)
|
| 636 |
-
if not config:
|
| 637 |
-
return {}
|
| 638 |
-
|
| 639 |
-
updates = {}
|
| 640 |
-
|
| 641 |
-
# Show/hide based on model capabilities
|
| 642 |
-
if model_name == "mcp_medsam":
|
| 643 |
-
updates[single_modality] = gr.update(visible=True)
|
| 644 |
-
updates[single_slice_idx] = gr.update(visible=False)
|
| 645 |
-
elif model_name == "nnunet":
|
| 646 |
-
updates[single_task] = gr.update(visible=True)
|
| 647 |
-
updates[single_box] = gr.update(visible=False)
|
| 648 |
-
updates[single_modality] = gr.update(visible=False)
|
| 649 |
-
elif model_name == "sam_med3d":
|
| 650 |
-
updates[single_points] = gr.update(visible=True)
|
| 651 |
-
updates[single_box] = gr.update(visible=False)
|
| 652 |
-
elif model_name == "medsam2":
|
| 653 |
-
updates[single_slice_idx] = gr.update(visible=True)
|
| 654 |
-
updates[single_modality] = gr.update(visible=False)
|
| 655 |
-
updates[single_task] = gr.update(visible=False)
|
| 656 |
-
else:
|
| 657 |
-
# Default visibility
|
| 658 |
-
updates[single_box] = gr.update(visible=config.capabilities.needs_prompt)
|
| 659 |
-
updates[single_slice_idx] = gr.update(visible=config.capabilities.supports_3d)
|
| 660 |
-
updates[single_modality] = gr.update(visible=False)
|
| 661 |
-
updates[single_task] = gr.update(visible=False)
|
| 662 |
-
updates[single_points] = gr.update(visible=False)
|
| 663 |
-
|
| 664 |
-
return updates
|
| 665 |
-
|
| 666 |
-
model_selector.change(
|
| 667 |
-
fn=update_model_inputs,
|
| 668 |
-
inputs=[model_selector],
|
| 669 |
-
outputs=[single_box, single_slice_idx, single_modality, single_task, single_points]
|
| 670 |
-
)
|
| 671 |
-
|
| 672 |
-
single_run_btn = gr.Button("🚀 Run Model", variant="primary")
|
| 673 |
-
|
| 674 |
-
with gr.Column(scale=2):
|
| 675 |
-
single_preview = gr.Image(label="Input Preview", type="pil")
|
| 676 |
-
single_output = gr.JSON(label="Results")
|
| 677 |
-
single_mask_viz = gr.Image(label="Segmentation Mask", type="pil")
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
def load_sample_for_display(sample_id: str):
|
| 681 |
-
"""Load sample image for display in browser."""
|
| 682 |
-
result = load_sample_image(sample_id)
|
| 683 |
-
if result is None:
|
| 684 |
-
return None
|
| 685 |
-
img_array, meta = result
|
| 686 |
-
return Image.fromarray(img_array)
|
| 687 |
|
| 688 |
|
| 689 |
-
def
|
| 690 |
-
"""
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
""")
|
| 699 |
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
type="pil",
|
| 708 |
-
width=400
|
| 709 |
-
)
|
| 710 |
-
|
| 711 |
-
with gr.Column():
|
| 712 |
-
gr.Markdown(f"**{sample_info['name']}**")
|
| 713 |
-
gr.Markdown(sample_info["description"])
|
| 714 |
-
gr.Markdown(f"**Modality:** {sample_info['modality']}")
|
| 715 |
-
gr.Markdown(f"**Suggested Box:** `{sample_info['default_box']}`")
|
| 716 |
-
gr.Markdown(f"**Source:** [Hugging Face Dataset]({sample_info['url']})")
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
def create_settings_tab():
|
| 720 |
-
"""Create settings/configuration tab."""
|
| 721 |
-
with gr.Tab("⚙️ Settings"):
|
| 722 |
-
gr.Markdown("""
|
| 723 |
-
## Model Configuration
|
| 724 |
|
| 725 |
-
|
| 726 |
-
|
|
|
|
| 727 |
|
| 728 |
-
#
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
model_data.append([
|
| 732 |
-
config.name,
|
| 733 |
-
"✅ Enabled" if config.enabled else "❌ Disabled",
|
| 734 |
-
config.category.title(),
|
| 735 |
-
"Yes" if config.capabilities.needs_prompt else "No",
|
| 736 |
-
", ".join(config.preferred_formats)
|
| 737 |
-
])
|
| 738 |
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
|
|
|
| 744 |
|
| 745 |
-
|
| 746 |
-
|
| 747 |
|
| 748 |
-
|
|
|
|
|
|
|
| 749 |
|
| 750 |
-
|
| 751 |
-
ENABLE_MEDSAM2=true/false
|
| 752 |
-
ENABLE_MCP_MEDSAM=true/false
|
| 753 |
-
ENABLE_SAM_MED3D=true/false
|
| 754 |
-
ENABLE_MEDSAM_3D=true/false
|
| 755 |
-
ENABLE_TRACTSEG=true/false
|
| 756 |
-
ENABLE_NNUNET=true/false
|
| 757 |
-
ENABLE_NEUROSAM3=true/false
|
| 758 |
-
```
|
| 759 |
-
""")
|
| 760 |
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
"
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
"
|
| 773 |
-
|
| 774 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
|
| 776 |
-
|
|
|
|
|
|
|
|
|
|
| 777 |
|
| 778 |
|
| 779 |
-
#
|
| 780 |
-
#
|
| 781 |
-
#
|
| 782 |
|
| 783 |
def create_interface():
|
| 784 |
-
"""Create
|
| 785 |
|
| 786 |
-
with gr.Blocks(
|
| 787 |
-
title="NeuroSeg Server - Multi-Model Medical Segmentation",
|
| 788 |
-
theme=gr.themes.Soft(),
|
| 789 |
-
css="""
|
| 790 |
-
.model-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
|
| 791 |
-
.comparison-result { background: #f5f5f5; padding: 10px; border-radius: 5px; }
|
| 792 |
-
"""
|
| 793 |
-
) as demo:
|
| 794 |
gr.Markdown("""
|
| 795 |
# 🧠 NeuroSeg Server
|
| 796 |
|
| 797 |
-
|
| 798 |
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
|
|
|
|
|
|
| 803 |
""")
|
| 804 |
|
| 805 |
-
#
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
create_settings_tab()
|
| 810 |
-
|
| 811 |
-
# Comparison run handler
|
| 812 |
-
def run_comparison(input_type, sample_id, upload_file, prompt_only, box_str, modality, *model_cbs):
|
| 813 |
-
"""Run selected models for comparison."""
|
| 814 |
-
# Get selected models from checkboxes
|
| 815 |
-
selected_models = []
|
| 816 |
-
for (model_key, cb), value in zip(comp_components["model_checkboxes"], model_cbs):
|
| 817 |
-
if value:
|
| 818 |
-
selected_models.append(model_key)
|
| 819 |
|
| 820 |
-
|
| 821 |
-
return "❌ No models selected", {}, None
|
| 822 |
|
| 823 |
-
|
| 824 |
-
if input_type == "Sample CT Scan":
|
| 825 |
-
result = load_sample_image(sample_id)
|
| 826 |
-
if result is None:
|
| 827 |
-
return "❌ Failed to load sample", {}, None
|
| 828 |
-
image, meta = result
|
| 829 |
-
else:
|
| 830 |
-
if upload_file is None:
|
| 831 |
-
return "❌ No file uploaded", {}, None
|
| 832 |
-
# Load uploaded file
|
| 833 |
-
img = Image.open(upload_file.name)
|
| 834 |
-
image = np.array(img.convert('L'))
|
| 835 |
-
meta = {"name": "uploaded"}
|
| 836 |
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
prompt = json.loads(box_str) if box_str else {}
|
| 840 |
-
except:
|
| 841 |
-
prompt = {}
|
| 842 |
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
results[model_name] = result
|
| 859 |
-
except Exception as e:
|
| 860 |
-
results[model_name] = {"error": str(e)}
|
| 861 |
|
| 862 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
|
| 864 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 865 |
|
| 866 |
-
#
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
comp_components["input_type"],
|
| 872 |
-
comp_components["sample_select"],
|
| 873 |
-
comp_components["file_upload"],
|
| 874 |
-
comp_components["prompt_only"],
|
| 875 |
-
comp_components["box"],
|
| 876 |
-
comp_components["modality"],
|
| 877 |
-
*model_cb_values
|
| 878 |
-
],
|
| 879 |
-
outputs=[
|
| 880 |
-
comp_components["status"],
|
| 881 |
-
comp_components["results"],
|
| 882 |
-
comp_components["viz"]
|
| 883 |
-
]
|
| 884 |
-
)
|
| 885 |
|
| 886 |
return demo
|
| 887 |
|
| 888 |
|
| 889 |
-
#
|
| 890 |
-
#
|
| 891 |
-
#
|
| 892 |
|
| 893 |
if __name__ == "__main__":
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
logger.info(f"Starting NeuroSeg Server with {len(enabled)} models: {enabled}")
|
| 897 |
-
|
| 898 |
-
# Check samples
|
| 899 |
-
sample_urls = [k for k in SAMPLE_IMAGES.keys()]
|
| 900 |
-
logger.info(f"Sample configs available: {sample_urls}")
|
| 901 |
|
| 902 |
demo = create_interface()
|
| 903 |
demo.launch(
|
| 904 |
server_name="0.0.0.0",
|
| 905 |
server_port=7860,
|
| 906 |
share=False,
|
| 907 |
-
show_api=True
|
| 908 |
-
quiet=False
|
| 909 |
)
|
|
|
|
| 1 |
"""
|
| 2 |
+
NeuroSeg Server — HydroMorph Backend API
|
| 3 |
+
=========================================
|
| 4 |
+
Backend API for HydroMorph React Native app (iOS, Android, Web).
|
| 5 |
+
|
| 6 |
+
ENDPOINTS FOR MOBILE APP:
|
| 7 |
+
- POST /gradio_api/upload — Upload PNG slice
|
| 8 |
+
- POST /gradio_api/call/{endpoint} — Call segmentation endpoint
|
| 9 |
+
- GET /gradio_api/call/{endpoint}/{event_id} — SSE poll for result
|
| 10 |
+
- GET /gradio_api/file={path} — Download result image
|
| 11 |
+
- POST /api/segment_2d — Direct JSON API (no Gradio protocol)
|
| 12 |
+
- POST /api/segment_3d — Direct JSON API for 3D volumes
|
| 13 |
+
- GET /api/health — Health check
|
| 14 |
|
| 15 |
MODELS SUPPORTED:
|
| 16 |
+
- MedSAM2: 3D volume with bi-directional propagation
|
| 17 |
+
- MCP-MedSAM: Fast 2D with modality/content prompts
|
| 18 |
+
- SAM-Med3D: Native 3D (245+ classes, sliding window)
|
| 19 |
+
- MedSAM-3D: 3D with memory bank
|
| 20 |
+
- TractSeg: White matter bundles (72 tracts)
|
| 21 |
+
- nnU-Net: Self-configuring U-Net
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
Author: Matheus Machado Rech
|
| 24 |
"""
|
|
|
|
| 30 |
import os
|
| 31 |
import tempfile
|
| 32 |
import base64
|
| 33 |
+
import time
|
| 34 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 35 |
from dataclasses import dataclass, field
|
|
|
|
| 36 |
from pathlib import Path
|
| 37 |
+
from functools import wraps
|
| 38 |
|
| 39 |
import gradio as gr
|
| 40 |
import spaces
|
|
|
|
| 44 |
import torch.nn.functional as F
|
| 45 |
from PIL import Image, ImageDraw
|
| 46 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 47 |
import nibabel as nib
|
| 48 |
import scipy
|
| 49 |
|
| 50 |
+
# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s")
logger = logging.getLogger("neuroseg_server")

# Paths
SCRIPT_DIR = Path(__file__).parent.resolve()
# Model weights are downloaded/cached here.
CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True)
# Result PNGs served back to clients (e.g. by process_with_status) land here.
TEMP_DIR = SCRIPT_DIR / "temp"
TEMP_DIR.mkdir(exist_ok=True)

# Model configs
# Each entry: display "name", whether the backend is wired up ("enabled"),
# and which input dimensionalities the model accepts.
MODELS = {
    "medsam2": {"name": "MedSAM2", "enabled": True, "supports_3d": True, "supports_2d": False},
    "mcp_medsam": {"name": "MCP-MedSAM", "enabled": True, "supports_3d": False, "supports_2d": True},
    "sam_med3d": {"name": "SAM-Med3D", "enabled": False, "supports_3d": True, "supports_2d": False},
    "medsam_3d": {"name": "MedSAM-3D", "enabled": False, "supports_3d": True, "supports_2d": False},
    "tractseg": {"name": "TractSeg", "enabled": True, "supports_3d": True, "supports_2d": False},
    "nnunet": {"name": "nnU-Net", "enabled": True, "supports_3d": True, "supports_2d": True},
}

# Maps UI modality labels to integer codes (note "MRI" and "MR" share code 1).
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
# =============================================================================
|
| 74 |
+
# UTILITY FUNCTIONS
|
| 75 |
+
# =============================================================================
|
| 76 |
|
| 77 |
+
def compress_mask(mask: np.ndarray) -> str:
    """Serialize *mask* as a gzip-compressed, base64-encoded ``.npy`` payload.

    Inverse of :func:`decompress_mask`; used to ship masks over JSON.
    """
    raw = io.BytesIO()
    gz = gzip.GzipFile(fileobj=raw, mode="wb")
    try:
        np.save(gz, mask)
    finally:
        gz.close()
    return base64.b64encode(raw.getvalue()).decode("ascii")
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
def decompress_mask(mask_b64: str) -> np.ndarray:
    """Restore a mask array produced by :func:`compress_mask`."""
    packed = base64.b64decode(mask_b64)
    with gzip.GzipFile(fileobj=io.BytesIO(packed), mode="rb") as gz:
        return np.load(gz)
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a base64 ASCII string (PNG container)."""
    with io.BytesIO() as out:
        img.save(out, format="PNG")
        payload = out.getvalue()
    return base64.b64encode(payload).decode("ascii")
|
| 97 |
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
def base64_to_image(b64: str) -> Image.Image:
    """Decode a base64 image payload back into a PIL image."""
    raw = base64.b64decode(b64)
    return Image.open(io.BytesIO(raw))
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
+
def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
    """Blend a binary mask onto an image as a translucent colored overlay.

    Args:
        image: 2D grayscale or 3-channel array; non-uint8 input is min/max
            normalized to 0-255.
        mask: Array of the same spatial shape; nonzero pixels are painted.
        color: RGB color of the overlay.
        alpha: Overlay opacity in [0, 1].

    Returns:
        RGB PIL image with the mask blended in.
    """
    # Normalize to uint8. Guard the constant-image case: max == min would
    # divide by zero and fill the output with NaN-derived garbage.
    if image.dtype != np.uint8:
        lo, hi = float(image.min()), float(image.max())
        scale = (hi - lo) or 1.0
        img_norm = ((image - lo) / scale * 255).astype(np.uint8)
    else:
        img_norm = image

    # Promote grayscale to 3-channel RGB.
    if img_norm.ndim == 2:
        img_rgb = np.stack([img_norm] * 3, axis=-1)
    else:
        img_rgb = img_norm

    overlay = img_rgb.astype(np.float32)
    mask_bool = mask.astype(bool)

    # Vectorized alpha blend on the masked pixels only (replaces the original
    # per-channel Python loop).
    overlay[mask_bool] = (1.0 - alpha) * overlay[mask_bool] + alpha * np.asarray(color, dtype=np.float32)

    return Image.fromarray(np.clip(overlay, 0, 255).astype(np.uint8))
|
| 127 |
|
| 128 |
|
| 129 |
+
# =============================================================================
|
| 130 |
+
# MODEL INFERENCE FUNCTIONS
|
| 131 |
+
# =============================================================================
|
| 132 |
|
| 133 |
@spaces.GPU(duration=120)
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
    """Run MedSAM2 on a 3D volume (mock implementation — replace with real model).

    Args:
        volume_bytes: Raw bytes of a ``.npy`` volume. Payloads that are not a
            valid plain ``.npy`` (e.g. gzip) fall back to a random demo volume.
        box_json: JSON string with box coordinates and optional "slice_idx".

    Returns:
        Dict with the raw mask, compressed mask, shape, and method name.
    """
    box = json.loads(box_json)
    logger.info(f"MedSAM2: Processing 3D volume with box {box}")

    # Load volume; fall back to a random demo volume for unparseable payloads
    # (more robust than the previous single-magic-byte sniff).
    try:
        volume = np.load(io.BytesIO(volume_bytes))
    except Exception:
        volume = np.random.rand(64, 256, 256)

    # Mock ellipsoid mask centered on the seed slice, computed vectorized —
    # the original triple Python loop visited every voxel (~4M iterations).
    D, H, W = volume.shape[:3]
    cz = box.get("slice_idx", D // 2)
    cy, cx = H // 2, W // 2
    zz, yy, xx = np.ogrid[:D, :H, :W]
    inside = ((zz - cz) / 10) ** 2 + ((yy - cy) / 40) ** 2 + ((xx - cx) / 40) ** 2 <= 1
    mask = np.zeros((D, H, W), dtype=np.uint8)
    mask[inside] = 1

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "medsam2"
    }
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@spaces.GPU(duration=60)
def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dict:
    """Run MCP-MedSAM on a 2D image (mock implementation — replace with real model).

    Args:
        image: 2D grayscale array.
        box: Dict with "x1", "y1", "x2", "y2" pixel coordinates.
        modality: Imaging modality tag, passed through to the result.

    Returns:
        Dict with the mask, compressed mask, shape, method, and modality.
    """
    logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")

    H, W = image.shape[:2]
    mask = np.zeros((H, W), dtype=np.uint8)

    # Normalize coordinate order and clamp to image bounds so a malformed
    # client box (swapped corners, off-image values) cannot yield an empty
    # or negatively-indexed slice.
    x1, x2 = sorted((int(box["x1"]), int(box["x2"])))
    y1, y2 = sorted((int(box["y1"]), int(box["y2"])))
    x1, x2 = max(0, x1), min(W, x2)
    y1, y2 = max(0, y1), min(H, y2)
    mask[y1:y2, x1:x2] = 1

    # Smooth edges (mock)
    from scipy import ndimage
    mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)

    return {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "mcp_medsam",
        "modality": modality
    }
|
| 186 |
|
| 187 |
|
| 188 |
+
@spaces.GPU(duration=90)
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
    """Run SAM-Med3D (mock implementation).

    NOTE: *points* and *labels* are accepted for API compatibility but the
    mock ignores them; the returned mask is random multi-class noise over
    the volume's spatial extent.
    """
    logger.info(f"SAM-Med3D: Processing with points {points}")

    spatial_shape = volume.shape[:3]
    mask = np.random.randint(0, 5, size=spatial_shape, dtype=np.uint8)

    result = {
        "mask": mask,
        "mask_b64": compress_mask(mask),
        "shape": list(mask.shape),
        "method": "sam_med3d",
    }
    return result
|
| 202 |
|
| 203 |
+
|
| 204 |
+
# =============================================================================
|
| 205 |
+
# API ENDPOINTS FOR MOBILE APP
|
| 206 |
+
# =============================================================================
|
| 207 |
+
|
| 208 |
+
def api_health():
    """Health-check endpoint: reports enabled models, compute device, version."""
    available = [name for name, cfg in MODELS.items() if cfg["enabled"]]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return {
        "status": "healthy",
        "models": available,
        "device": device,
        "version": "2.0.0",
    }
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
    """
    Direct JSON API for 2D segmentation (no Gradio protocol).

    Args:
        image_file: Gradio File object (has ``.name``) or a plain path.
        box_json: JSON string {"x1": int, "y1": int, "x2": int, "y2": int}
        model: Model to use (default: mcp_medsam)
        modality: Imaging modality

    Returns:
        JSON dict with mask_b64, overlay image, and metadata, or {"error": ...}.
    """
    try:
        box = json.loads(box_json)

        # Resolve the path whether we got a Gradio file object or a string.
        path = image_file.name if hasattr(image_file, "name") else image_file
        image = np.array(Image.open(path).convert("L"))

        if model != "mcp_medsam":
            return {"error": f"Model {model} not supported for 2D"}
        result = run_mcp_medsam_2d(image, box, modality)

        # Visualization for the client.
        overlay_b64 = image_to_base64(overlay_mask_on_image(image, result["mask"]))

        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "overlay_b64": overlay_b64,
            "shape": result["shape"],
            "method": result["method"],
            "modality": result.get("modality", modality),
        }
    except Exception as e:
        logger.exception("2D segmentation failed")
        return {"error": str(e)}
|
| 265 |
|
| 266 |
|
| 267 |
+
def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
    """
    Direct JSON API for 3D segmentation.

    Args:
        volume_file: .npy or .nii.gz file (Gradio File object or path).
        box_json: JSON with box coordinates and slice_idx.
        model: Model to use (default: medsam2).

    Returns:
        JSON dict with mask_b64 and metadata, or {"error": ...}.
    """
    try:
        path = volume_file.name if hasattr(volume_file, "name") else volume_file
        volume_bytes = Path(path).read_bytes()

        if model != "medsam2":
            return {"error": f"Model {model} not supported for 3D"}
        result = run_medsam2_3d(volume_bytes, box_json)

        return {
            "success": True,
            "mask_b64": result["mask_b64"],
            "shape": result["shape"],
            "method": result["method"],
        }
    except Exception as e:
        logger.exception("3D segmentation failed")
        return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
|
| 307 |
+
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
    """
    Gradio-compatible endpoint for the HydroMorph app.

    Expected by GradioClient.js segmentImage():
    - Input: [file_ref, prompt, modality, windowType]
    - Output: [segmentation_image, status_text]

    Returns:
        (image_path, status) tuple for Gradio; (None, "Error: ...") on failure.
    """
    try:
        logger.info(f"process_with_status: prompt={prompt}, modality={modality}")

        if image_file is None:
            return None, "Error: No image provided"

        # Handle the Gradio file-reference shapes the mobile client may send.
        if isinstance(image_file, dict):
            file_path = image_file.get("path")
        elif hasattr(image_file, 'name'):
            file_path = image_file.name
        else:
            file_path = str(image_file)

        image = np.array(Image.open(file_path).convert('L'))

        # Mock segmentation: elliptical mask in the image center ("ventricles"),
        # vectorized instead of the original per-pixel Python loop.
        H, W = image.shape
        cy, cx = H // 2, W // 2
        yy, xx = np.ogrid[:H, :W]
        inside = ((yy - cy) / (H / 4)) ** 2 + ((xx - cx) / (W / 4)) ** 2 <= 1
        mask = inside.astype(np.uint8)

        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))

        # Unique temp file per request. The previous int(time.time()) naming
        # collided for concurrent requests within the same second, silently
        # overwriting one client's result with another's.
        with tempfile.NamedTemporaryFile(
            dir=TEMP_DIR, prefix="result_", suffix=".png", delete=False
        ) as tmp:
            temp_path = tmp.name
        overlay.save(temp_path)

        status = f"Segmented {prompt} using {modality} window"
        return str(temp_path), status

    except Exception as e:
        logger.exception("process_with_status failed")
        return None, f"Error: {str(e)}"
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
    """
    Compare multiple models on the same image.

    Args:
        image_file: Input image (Gradio File object or path).
        box_json: Bounding box JSON string.
        models_json: JSON array of model names, e.g. ["mcp_medsam", "medsam2"].
        modality: Imaging modality.

    Returns:
        JSON dict with a per-model result map, or {"error": ...}.
    """
    try:
        requested = json.loads(models_json)
        box = json.loads(box_json)

        path = image_file.name if hasattr(image_file, "name") else image_file
        image = np.array(Image.open(path).convert("L"))

        results = {}
        for model in requested:
            start = time.time()
            try:
                if model == "mcp_medsam":
                    result = run_mcp_medsam_2d(image, box, modality)
                    overlay = overlay_mask_on_image(image, result["mask"], color=(0, 255, 0))
                elif model == "medsam2" and image_file:
                    # 2D slice mode: fall back to MCP-MedSAM for the demo,
                    # rendered in red so the two overlays are distinguishable.
                    result = run_mcp_medsam_2d(image, box, modality)
                    overlay = overlay_mask_on_image(image, result["mask"], color=(255, 0, 0))
                else:
                    continue

                results[model] = {
                    "success": True,
                    "mask_b64": result["mask_b64"],
                    "overlay_b64": image_to_base64(overlay),
                    "inference_time": round(time.time() - start, 2),
                    "shape": result["shape"],
                }
            except Exception as e:
                results[model] = {"success": False, "error": str(e)}

        return {"success": True, "results": results}

    except Exception as e:
        return {"error": str(e)}
|
| 417 |
|
| 418 |
|
| 419 |
+
# =============================================================================
|
| 420 |
+
# GRADIO INTERFACE
|
| 421 |
+
# =============================================================================
|
| 422 |
|
| 423 |
def create_interface():
    """Create Gradio interface with HydroMorph-compatible endpoints."""

    with gr.Blocks(title="NeuroSeg Server - HydroMorph Backend", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🧠 NeuroSeg Server

        Backend API for HydroMorph React Native app (iOS, Android, Web).

        **Mobile-Compatible Endpoints:**
        - `POST /gradio_api/upload` - Upload PNG slice
        - `POST /gradio_api/call/process_with_status` - Segment with status
        - `POST /api/segment_2d` - Direct JSON API for 2D
        - `POST /api/segment_3d` - Direct JSON API for 3D
        - `GET /api/health` - Health check
        """)

        # --- HydroMorph Mobile Endpoint ---
        # NOTE: the api_name strings below form the wire contract with
        # GradioClient.js in the mobile app — renaming them breaks clients.
        with gr.Tab("📱 Mobile App Endpoint"):
            gr.Markdown("""
            This endpoint is used by the HydroMorph mobile app.

            **Endpoint:** `POST /gradio_api/call/process_with_status`

            **Input Format:** `[file_ref, prompt, modality, window_type]`

            **Output Format:** `[segmentation_image, status_text]`
            """)

            with gr.Row():
                with gr.Column():
                    mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
                    mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
                    mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
                    mobile_window = gr.Dropdown(
                        label="Window",
                        choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
                        value="Brain (Grey Matter)"
                    )
                    mobile_btn = gr.Button("Run Segmentation", variant="primary")

                with gr.Column():
                    mobile_result_img = gr.Image(label="Segmentation Result")
                    mobile_status = gr.Textbox(label="Status")

            mobile_btn.click(
                fn=process_with_status,
                inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
                outputs=[mobile_result_img, mobile_status],
                api_name="process_with_status"
            )

        # --- 2D Segmentation ---
        with gr.Tab("🎯 2D Segmentation"):
            with gr.Row():
                with gr.Column():
                    seg2d_image = gr.Image(label="Image", type="filepath")
                    seg2d_box = gr.Textbox(
                        label="Bounding Box (JSON)",
                        value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
                    )
                    seg2d_model = gr.Dropdown(
                        label="Model",
                        choices=["mcp_medsam"],
                        value="mcp_medsam"
                    )
                    seg2d_modality = gr.Dropdown(
                        label="Modality",
                        choices=list(MODALITY_MAP.keys()),
                        value="CT"
                    )
                    seg2d_btn = gr.Button("Segment", variant="primary")

                with gr.Column():
                    seg2d_output = gr.JSON(label="Result")
                    seg2d_overlay = gr.Image(label="Overlay")

            # UI wrapper: shows the decoded overlay alongside the raw JSON.
            def segment_2d_with_overlay(image, box, model, modality):
                result = api_segment_2d(image, box, model, modality)
                if "error" in result:
                    return result, None

                # Decompress and display overlay
                mask = decompress_mask(result["mask_b64"])
                img = Image.open(image.name if hasattr(image, 'name') else image).convert('L')
                overlay = overlay_mask_on_image(np.array(img), mask)

                return result, overlay

            seg2d_btn.click(
                fn=segment_2d_with_overlay,
                inputs=[seg2d_image, seg2d_box, seg2d_model, seg2d_modality],
                outputs=[seg2d_output, seg2d_overlay],
                api_name="segment_2d"
            )

        # --- 3D Segmentation ---
        with gr.Tab("🏥 3D Segmentation"):
            with gr.Row():
                with gr.Column():
                    seg3d_volume = gr.File(label="Volume (.npy or .nii.gz)", file_types=[".npy", ".nii.gz"])
                    seg3d_box = gr.Textbox(
                        label="Box + Slice (JSON)",
                        value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200, "slice_idx": 32}'
                    )
                    seg3d_model = gr.Dropdown(
                        label="Model",
                        choices=["medsam2"],
                        value="medsam2"
                    )
                    seg3d_btn = gr.Button("Segment", variant="primary")

                with gr.Column():
                    seg3d_output = gr.JSON(label="Result")

            seg3d_btn.click(
                fn=api_segment_3d,
                inputs=[seg3d_volume, seg3d_box, seg3d_model],
                outputs=seg3d_output,
                api_name="segment_3d"
            )

        # --- Model Comparison ---
        with gr.Tab("🔬 Compare Models"):
            with gr.Row():
                with gr.Column():
                    comp_image = gr.Image(label="Image", type="filepath")
                    comp_box = gr.Textbox(
                        label="Box (JSON)",
                        value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
                    )
                    comp_models = gr.CheckboxGroup(
                        label="Models to Compare",
                        choices=["mcp_medsam", "medsam2"],
                        value=["mcp_medsam"]
                    )
                    comp_modality = gr.Dropdown(
                        label="Modality",
                        choices=list(MODALITY_MAP.keys()),
                        value="CT"
                    )
                    comp_btn = gr.Button("Run Comparison", variant="primary")

                with gr.Column():
                    comp_output = gr.JSON(label="Comparison Results")

            # CheckboxGroup yields a Python list; api_compare_models expects JSON.
            def run_comparison(image, box, models, modality):
                return api_compare_models(image, box, json.dumps(models), modality)

            comp_btn.click(
                fn=run_comparison,
                inputs=[comp_image, comp_box, comp_models, comp_modality],
                outputs=comp_output,
                api_name="compare_models"
            )

        # --- Health Check ---
        with gr.Tab("⚙️ Status"):
            health_btn = gr.Button("Check Health")
            health_output = gr.JSON(label="System Status")
            health_btn.click(fn=api_health, outputs=health_output, api_name="health")

    return demo
|
| 586 |
|
| 587 |
|
| 588 |
+
# =============================================================================
|
| 589 |
+
# MAIN
|
| 590 |
+
# =============================================================================
|
| 591 |
|
| 592 |
if __name__ == "__main__":
    enabled_models = [k for k, v in MODELS.items() if v['enabled']]
    logger.info("Starting NeuroSeg Server for HydroMorph")
    logger.info(f"Enabled models: {enabled_models}")

    demo = create_interface()
    # Bind all interfaces so the HF Space / mobile clients can reach the server;
    # show_api=True exposes the /gradio_api endpoints the app consumes.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True
    )
|