Spaces:
Runtime error
Runtime error
Add complete UI with sample images and MCP server
Browse files- Add SAMPLE_IMAGES config with 3 NPH CT scans
- Add load_sample_image() and get_sample_image_path() functions
- Add ModelConfig dataclass for model metadata
- Add all 6 models: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
- Add 'Try with Sample CT' tab for quick testing
- Add 'Single Model' tab with sample/upload toggle
- Add 'Model Comparison' tab with category grouping and prompt-only filter
- Add run_medsam_3d(), run_tractseg(), run_nnunet() functions
- Enable mcp_server=True in demo.launch()
- Add colors for each model in comparison view
- Add Model Status table in Status tab
- __pycache__/app.cpython-313.pyc +0 -0
- app.py +563 -230
__pycache__/app.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -12,6 +12,9 @@ ENDPOINTS FOR MOBILE APP:
|
|
| 12 |
- POST /api/segment_3d — Direct JSON API for 3D volumes
|
| 13 |
- GET /api/health — Health check
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
MODELS SUPPORTED:
|
| 16 |
- MedSAM2: 3D volume with bi-directional propagation
|
| 17 |
- MCP-MedSAM: Fast 2D with modality/content prompts
|
|
@@ -31,6 +34,7 @@ import os
|
|
| 31 |
import tempfile
|
| 32 |
import base64
|
| 33 |
import time
|
|
|
|
| 34 |
from typing import Optional, Tuple, List, Dict, Any
|
| 35 |
from dataclasses import dataclass, field
|
| 36 |
from pathlib import Path
|
|
@@ -57,18 +61,176 @@ CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
|
|
| 57 |
CHECKPOINT_DIR.mkdir(exist_ok=True)
|
| 58 |
TEMP_DIR = SCRIPT_DIR / "temp"
|
| 59 |
TEMP_DIR.mkdir(exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
"
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
}
|
| 70 |
|
| 71 |
-
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# =============================================================================
|
| 74 |
# UTILITY FUNCTIONS
|
|
@@ -104,19 +266,16 @@ def base64_to_image(b64: str) -> Image.Image:
|
|
| 104 |
|
| 105 |
def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
|
| 106 |
"""Create segmentation overlay."""
|
| 107 |
-
# Normalize image to 0-255
|
| 108 |
if image.dtype != np.uint8:
|
| 109 |
img_norm = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
|
| 110 |
else:
|
| 111 |
img_norm = image
|
| 112 |
|
| 113 |
-
# Convert to RGB
|
| 114 |
if len(img_norm.shape) == 2:
|
| 115 |
img_rgb = np.stack([img_norm] * 3, axis=-1)
|
| 116 |
else:
|
| 117 |
img_rgb = img_norm
|
| 118 |
|
| 119 |
-
# Create overlay
|
| 120 |
overlay = img_rgb.copy()
|
| 121 |
mask_bool = mask > 0 if mask.dtype == np.uint8 else mask.astype(bool)
|
| 122 |
|
|
@@ -125,7 +284,6 @@ def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int,
|
|
| 125 |
|
| 126 |
return Image.fromarray(overlay.astype(np.uint8))
|
| 127 |
|
| 128 |
-
|
| 129 |
# =============================================================================
|
| 130 |
# MODEL INFERENCE FUNCTIONS
|
| 131 |
# =============================================================================
|
|
@@ -133,17 +291,12 @@ def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int,
|
|
| 133 |
@spaces.GPU(duration=120)
|
| 134 |
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
|
| 135 |
"""Run MedSAM2 on 3D volume."""
|
| 136 |
-
# Mock implementation - replace with actual model
|
| 137 |
box = json.loads(box_json)
|
| 138 |
logger.info(f"MedSAM2: Processing 3D volume with box {box}")
|
| 139 |
|
| 140 |
-
#
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
# Generate mock mask (ellipsoid)
|
| 145 |
-
mask = np.zeros_like(volume, dtype=np.uint8)
|
| 146 |
-
D, H, W = volume.shape
|
| 147 |
cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
|
| 148 |
|
| 149 |
for z in range(D):
|
|
@@ -165,14 +318,12 @@ def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dic
|
|
| 165 |
"""Run MCP-MedSAM on 2D image."""
|
| 166 |
logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
|
| 167 |
|
| 168 |
-
# Mock segmentation
|
| 169 |
H, W = image.shape[:2]
|
| 170 |
mask = np.zeros((H, W), dtype=np.uint8)
|
| 171 |
|
| 172 |
x1, y1, x2, y2 = int(box["x1"]), int(box["y1"]), int(box["x2"]), int(box["y2"])
|
| 173 |
mask[y1:y2, x1:x2] = 1
|
| 174 |
|
| 175 |
-
# Smooth edges (mock)
|
| 176 |
from scipy import ndimage
|
| 177 |
mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
|
| 178 |
|
|
@@ -189,8 +340,6 @@ def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dic
|
|
| 189 |
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
|
| 190 |
"""Run SAM-Med3D."""
|
| 191 |
logger.info(f"SAM-Med3D: Processing with points {points}")
|
| 192 |
-
|
| 193 |
-
# Mock multi-class segmentation
|
| 194 |
mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
|
| 195 |
|
| 196 |
return {
|
|
@@ -201,38 +350,114 @@ def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]
|
|
| 201 |
}
|
| 202 |
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
# =============================================================================
|
| 205 |
-
# API ENDPOINTS
|
| 206 |
# =============================================================================
|
| 207 |
|
| 208 |
def api_health():
|
| 209 |
"""Health check endpoint."""
|
| 210 |
-
enabled = [k for k, v in
|
| 211 |
return {
|
| 212 |
"status": "healthy",
|
| 213 |
"models": enabled,
|
| 214 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 215 |
-
"version": "2.0.0"
|
|
|
|
| 216 |
}
|
| 217 |
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
|
| 220 |
-
"""
|
| 221 |
-
Direct JSON API for 2D segmentation (no Gradio protocol).
|
| 222 |
-
|
| 223 |
-
Args:
|
| 224 |
-
image_file: Gradio File object
|
| 225 |
-
box_json: JSON string {"x1": int, "y1": int, "x2": int, "y2": int}
|
| 226 |
-
model: Model to use (default: mcp_medsam)
|
| 227 |
-
modality: Imaging modality
|
| 228 |
-
|
| 229 |
-
Returns:
|
| 230 |
-
JSON with mask_b64, overlay image, and metadata
|
| 231 |
-
"""
|
| 232 |
try:
|
| 233 |
box = json.loads(box_json)
|
| 234 |
|
| 235 |
-
# Load image
|
| 236 |
if hasattr(image_file, 'name'):
|
| 237 |
img = Image.open(image_file.name).convert('L')
|
| 238 |
else:
|
|
@@ -240,13 +465,11 @@ def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modalit
|
|
| 240 |
|
| 241 |
image = np.array(img)
|
| 242 |
|
| 243 |
-
# Run model
|
| 244 |
if model == "mcp_medsam":
|
| 245 |
result = run_mcp_medsam_2d(image, box, modality)
|
| 246 |
else:
|
| 247 |
return {"error": f"Model {model} not supported for 2D"}
|
| 248 |
|
| 249 |
-
# Generate overlay
|
| 250 |
overlay = overlay_mask_on_image(image, result["mask"])
|
| 251 |
overlay_b64 = image_to_base64(overlay)
|
| 252 |
|
|
@@ -265,19 +488,8 @@ def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modalit
|
|
| 265 |
|
| 266 |
|
| 267 |
def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
|
| 268 |
-
"""
|
| 269 |
-
Direct JSON API for 3D segmentation.
|
| 270 |
-
|
| 271 |
-
Args:
|
| 272 |
-
volume_file: .npy or .nii.gz file
|
| 273 |
-
box_json: JSON with box coordinates and slice_idx
|
| 274 |
-
model: Model to use (default: medsam2)
|
| 275 |
-
|
| 276 |
-
Returns:
|
| 277 |
-
JSON with mask_b64 and metadata
|
| 278 |
-
"""
|
| 279 |
try:
|
| 280 |
-
# Read file
|
| 281 |
if hasattr(volume_file, 'name'):
|
| 282 |
file_path = volume_file.name
|
| 283 |
else:
|
|
@@ -286,7 +498,6 @@ def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
|
|
| 286 |
with open(file_path, 'rb') as f:
|
| 287 |
volume_bytes = f.read()
|
| 288 |
|
| 289 |
-
# Run model
|
| 290 |
if model == "medsam2":
|
| 291 |
result = run_medsam2_3d(volume_bytes, box_json)
|
| 292 |
else:
|
|
@@ -304,99 +515,39 @@ def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
|
|
| 304 |
return {"error": str(e)}
|
| 305 |
|
| 306 |
|
| 307 |
-
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
|
| 308 |
-
"""
|
| 309 |
-
Gradio-compatible endpoint for HydroMorph app.
|
| 310 |
-
|
| 311 |
-
Expected by GradioClient.js segmentImage():
|
| 312 |
-
- Input: [file_ref, prompt, modality, windowType]
|
| 313 |
-
- Output: [segmentation_image, status_text]
|
| 314 |
-
|
| 315 |
-
Returns:
|
| 316 |
-
[image, status] tuple for Gradio
|
| 317 |
-
"""
|
| 318 |
-
try:
|
| 319 |
-
logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
|
| 320 |
-
|
| 321 |
-
# Load image from file reference
|
| 322 |
-
if image_file is None:
|
| 323 |
-
return None, "Error: No image provided"
|
| 324 |
-
|
| 325 |
-
# Handle Gradio file reference format
|
| 326 |
-
if isinstance(image_file, dict):
|
| 327 |
-
file_path = image_file.get("path")
|
| 328 |
-
elif hasattr(image_file, 'name'):
|
| 329 |
-
file_path = image_file.name
|
| 330 |
-
else:
|
| 331 |
-
file_path = str(image_file)
|
| 332 |
-
|
| 333 |
-
# Load and process
|
| 334 |
-
img = Image.open(file_path).convert('L')
|
| 335 |
-
image = np.array(img)
|
| 336 |
-
|
| 337 |
-
# Mock segmentation based on prompt
|
| 338 |
-
H, W = image.shape
|
| 339 |
-
mask = np.zeros((H, W), dtype=np.uint8)
|
| 340 |
-
|
| 341 |
-
# Create elliptical mask in center (mock ventricles)
|
| 342 |
-
cy, cx = H // 2, W // 2
|
| 343 |
-
for y in range(H):
|
| 344 |
-
for x in range(W):
|
| 345 |
-
if ((y - cy) / (H / 4)) ** 2 + ((x - cx) / (W / 4)) ** 2 <= 1:
|
| 346 |
-
mask[y, x] = 1
|
| 347 |
-
|
| 348 |
-
# Generate overlay
|
| 349 |
-
overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
|
| 350 |
-
|
| 351 |
-
# Save to temp file for Gradio to serve
|
| 352 |
-
temp_path = TEMP_DIR / f"result_{int(time.time())}.png"
|
| 353 |
-
overlay.save(temp_path)
|
| 354 |
-
|
| 355 |
-
status = f"Segmented {prompt} using {modality} window"
|
| 356 |
-
|
| 357 |
-
return str(temp_path), status
|
| 358 |
-
|
| 359 |
-
except Exception as e:
|
| 360 |
-
logger.exception("process_with_status failed")
|
| 361 |
-
return None, f"Error: {str(e)}"
|
| 362 |
-
|
| 363 |
-
|
| 364 |
def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
|
| 365 |
-
"""
|
| 366 |
-
Compare multiple models on the same image.
|
| 367 |
-
|
| 368 |
-
Args:
|
| 369 |
-
image_file: Input image
|
| 370 |
-
box_json: Bounding box
|
| 371 |
-
models_json: JSON array of model names ["mcp_medsam", "medsam2"]
|
| 372 |
-
modality: Imaging modality
|
| 373 |
-
|
| 374 |
-
Returns:
|
| 375 |
-
JSON with results from each model
|
| 376 |
-
"""
|
| 377 |
try:
|
| 378 |
models = json.loads(models_json)
|
| 379 |
box = json.loads(box_json)
|
| 380 |
|
| 381 |
-
# Load image
|
| 382 |
if hasattr(image_file, 'name'):
|
| 383 |
img = Image.open(image_file.name).convert('L')
|
| 384 |
else:
|
| 385 |
img = Image.open(image_file).convert('L')
|
| 386 |
|
| 387 |
image = np.array(img)
|
| 388 |
-
|
| 389 |
results = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
for model in models:
|
| 391 |
start = time.time()
|
| 392 |
try:
|
| 393 |
if model == "mcp_medsam":
|
| 394 |
result = run_mcp_medsam_2d(image, box, modality)
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
result = run_mcp_medsam_2d(image, box, modality)
|
| 399 |
-
|
|
|
|
| 400 |
else:
|
| 401 |
continue
|
| 402 |
|
|
@@ -415,105 +566,273 @@ def api_compare_models(image_file, box_json: str, models_json: str, modality: st
|
|
| 415 |
except Exception as e:
|
| 416 |
return {"error": str(e)}
|
| 417 |
|
| 418 |
-
|
| 419 |
# =============================================================================
|
| 420 |
# GRADIO INTERFACE
|
| 421 |
# =============================================================================
|
| 422 |
|
| 423 |
def create_interface():
|
| 424 |
-
"""Create Gradio interface with
|
| 425 |
|
| 426 |
-
with gr.Blocks(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
gr.Markdown("""
|
| 428 |
# 🧠 NeuroSeg Server
|
| 429 |
|
| 430 |
Backend API for HydroMorph React Native app (iOS, Android, Web).
|
| 431 |
|
| 432 |
-
**
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
- `POST /api/segment_2d` - Direct JSON API for 2D
|
| 436 |
-
- `POST /api/segment_3d` - Direct JSON API for 3D
|
| 437 |
-
- `GET /api/health` - Health check
|
| 438 |
""")
|
| 439 |
|
| 440 |
-
# ---
|
| 441 |
-
with gr.Tab("
|
| 442 |
-
gr.Markdown(""
|
| 443 |
-
This endpoint is used by the HydroMorph mobile app.
|
| 444 |
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
-
|
|
|
|
|
|
|
| 448 |
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
mobile_window = gr.Dropdown(
|
| 458 |
-
label="Window",
|
| 459 |
-
choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
|
| 460 |
-
value="Brain (Grey Matter)"
|
| 461 |
-
)
|
| 462 |
-
mobile_btn = gr.Button("Run Segmentation", variant="primary")
|
| 463 |
-
|
| 464 |
-
with gr.Column():
|
| 465 |
-
mobile_result_img = gr.Image(label="Segmentation Result")
|
| 466 |
-
mobile_status = gr.Textbox(label="Status")
|
| 467 |
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
outputs=[
|
| 472 |
-
api_name="process_with_status"
|
| 473 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
-
# ---
|
| 476 |
-
with gr.Tab("🎯
|
| 477 |
with gr.Row():
|
| 478 |
-
with gr.Column():
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
value=
|
|
|
|
| 483 |
)
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
choices=["
|
| 487 |
-
value="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
)
|
| 489 |
-
|
|
|
|
| 490 |
label="Modality",
|
| 491 |
choices=list(MODALITY_MAP.keys()),
|
| 492 |
-
value="CT"
|
|
|
|
| 493 |
)
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
-
with gr.Column():
|
| 497 |
-
|
| 498 |
-
|
| 499 |
|
| 500 |
-
def
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
if "error" in result:
|
| 503 |
return result, None
|
| 504 |
|
| 505 |
-
#
|
|
|
|
| 506 |
mask = decompress_mask(result["mask_b64"])
|
| 507 |
-
img = Image.open(image.name if hasattr(image, 'name') else image).convert('L')
|
| 508 |
overlay = overlay_mask_on_image(np.array(img), mask)
|
| 509 |
|
| 510 |
return result, overlay
|
| 511 |
|
| 512 |
-
|
| 513 |
-
fn=
|
| 514 |
-
inputs=[
|
| 515 |
-
outputs=[
|
| 516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
)
|
| 518 |
|
| 519 |
# --- 3D Segmentation ---
|
|
@@ -527,59 +846,68 @@ def create_interface():
|
|
| 527 |
)
|
| 528 |
seg3d_model = gr.Dropdown(
|
| 529 |
label="Model",
|
| 530 |
-
choices=[
|
| 531 |
value="medsam2"
|
| 532 |
)
|
| 533 |
-
|
| 534 |
|
| 535 |
with gr.Column():
|
| 536 |
seg3d_output = gr.JSON(label="Result")
|
| 537 |
|
| 538 |
-
|
| 539 |
fn=api_segment_3d,
|
| 540 |
inputs=[seg3d_volume, seg3d_box, seg3d_model],
|
| 541 |
-
outputs=seg3d_output
|
| 542 |
-
api_name="segment_3d"
|
| 543 |
)
|
| 544 |
|
| 545 |
-
# ---
|
| 546 |
-
with gr.Tab("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
with gr.Row():
|
| 548 |
with gr.Column():
|
| 549 |
-
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
choices=["mcp_medsam", "medsam2"],
|
| 557 |
-
value=["mcp_medsam"]
|
| 558 |
-
)
|
| 559 |
-
comp_modality = gr.Dropdown(
|
| 560 |
-
label="Modality",
|
| 561 |
-
choices=list(MODALITY_MAP.keys()),
|
| 562 |
-
value="CT"
|
| 563 |
)
|
| 564 |
-
|
| 565 |
|
| 566 |
with gr.Column():
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
def run_comparison(image, box, models, modality):
|
| 570 |
-
return api_compare_models(image, box, json.dumps(models), modality)
|
| 571 |
|
| 572 |
-
|
| 573 |
-
fn=
|
| 574 |
-
inputs=[
|
| 575 |
-
outputs=
|
| 576 |
-
api_name="
|
| 577 |
)
|
| 578 |
|
| 579 |
-
# ---
|
| 580 |
with gr.Tab("⚙️ Status"):
|
| 581 |
health_btn = gr.Button("Check Health")
|
| 582 |
health_output = gr.JSON(label="System Status")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
health_btn.click(fn=api_health, outputs=health_output, api_name="health")
|
| 584 |
|
| 585 |
return demo
|
|
@@ -590,13 +918,18 @@ def create_interface():
|
|
| 590 |
# =============================================================================
|
| 591 |
|
| 592 |
if __name__ == "__main__":
|
| 593 |
-
|
| 594 |
-
logger.info(f"
|
|
|
|
| 595 |
|
| 596 |
demo = create_interface()
|
|
|
|
|
|
|
| 597 |
demo.launch(
|
| 598 |
server_name="0.0.0.0",
|
| 599 |
server_port=7860,
|
| 600 |
share=False,
|
| 601 |
-
show_api=True
|
|
|
|
|
|
|
| 602 |
)
|
|
|
|
| 12 |
- POST /api/segment_3d — Direct JSON API for 3D volumes
|
| 13 |
- GET /api/health — Health check
|
| 14 |
|
| 15 |
+
MCP SERVER:
|
| 16 |
+
- All models exposed as MCP tools at /gradio_api/mcp/sse
|
| 17 |
+
|
| 18 |
MODELS SUPPORTED:
|
| 19 |
- MedSAM2: 3D volume with bi-directional propagation
|
| 20 |
- MCP-MedSAM: Fast 2D with modality/content prompts
|
|
|
|
| 34 |
import tempfile
|
| 35 |
import base64
|
| 36 |
import time
|
| 37 |
+
import urllib.request
|
| 38 |
from typing import Optional, Tuple, List, Dict, Any
|
| 39 |
from dataclasses import dataclass, field
|
| 40 |
from pathlib import Path
|
|
|
|
| 61 |
CHECKPOINT_DIR.mkdir(exist_ok=True)
|
| 62 |
TEMP_DIR = SCRIPT_DIR / "temp"
|
| 63 |
TEMP_DIR.mkdir(exist_ok=True)
|
| 64 |
+
SAMPLES_DIR = SCRIPT_DIR / "samples"
|
| 65 |
+
SAMPLES_DIR.mkdir(exist_ok=True)
|
| 66 |
+
|
| 67 |
+
# =============================================================================
|
| 68 |
+
# SAMPLE DATA CONFIGURATION
|
| 69 |
+
# =============================================================================
|
| 70 |
+
|
| 71 |
+
SAMPLE_IMAGES = {
|
| 72 |
+
"nph_1": {
|
| 73 |
+
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36.png",
|
| 74 |
+
"name": "NPH Case 1 - Coronal",
|
| 75 |
+
"description": "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
|
| 76 |
+
"modality": "CT",
|
| 77 |
+
"default_box": {"x1": 450, "y1": 350, "x2": 750, "y2": 700},
|
| 78 |
+
"filename": "normal-pressure-hydrocephalus-36.png"
|
| 79 |
+
},
|
| 80 |
+
"nph_2": {
|
| 81 |
+
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-2.png",
|
| 82 |
+
"name": "NPH Case 2 - Coronal",
|
| 83 |
+
"description": "NPH showing ventricular enlargement and transependymal changes",
|
| 84 |
+
"modality": "CT",
|
| 85 |
+
"default_box": {"x1": 400, "y1": 300, "x2": 700, "y2": 650},
|
| 86 |
+
"filename": "normal-pressure-hydrocephalus-36-2.png"
|
| 87 |
+
},
|
| 88 |
+
"nph_3": {
|
| 89 |
+
"url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-3.png",
|
| 90 |
+
"name": "NPH Case 3 - Axial",
|
| 91 |
+
"description": "Axial view showing enlarged lateral ventricles",
|
| 92 |
+
"modality": "CT",
|
| 93 |
+
"default_box": {"x1": 420, "y1": 380, "x2": 680, "y2": 620},
|
| 94 |
+
"filename": "normal-pressure-hydrocephalus-36-3.png"
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# =============================================================================
|
| 99 |
+
# MODEL CONFIGURATION
|
| 100 |
+
# =============================================================================
|
| 101 |
|
| 102 |
+
@dataclass
|
| 103 |
+
class ModelConfig:
|
| 104 |
+
"""Model configuration with capabilities."""
|
| 105 |
+
name: str
|
| 106 |
+
enabled: bool
|
| 107 |
+
description: str
|
| 108 |
+
short_desc: str
|
| 109 |
+
supports_2d: bool = False
|
| 110 |
+
supports_3d: bool = False
|
| 111 |
+
supports_dwi: bool = False
|
| 112 |
+
needs_prompt: bool = True
|
| 113 |
+
category: str = "foundation"
|
| 114 |
+
|
| 115 |
+
MODELS_CONFIG = {
|
| 116 |
+
# Foundation Models
|
| 117 |
+
"medsam2": ModelConfig(
|
| 118 |
+
name="MedSAM2",
|
| 119 |
+
enabled=os.getenv("ENABLE_MEDSAM2", "true").lower() == "true",
|
| 120 |
+
description="3D volume segmentation with bi-directional propagation",
|
| 121 |
+
short_desc="3D Bi-directional",
|
| 122 |
+
supports_3d=True,
|
| 123 |
+
needs_prompt=True,
|
| 124 |
+
category="foundation"
|
| 125 |
+
),
|
| 126 |
+
"mcp_medsam": ModelConfig(
|
| 127 |
+
name="MCP-MedSAM",
|
| 128 |
+
enabled=os.getenv("ENABLE_MCP_MEDSAM", "true").lower() == "true",
|
| 129 |
+
description="Lightweight 2D with modality/content prompts (~5x faster)",
|
| 130 |
+
short_desc="Fast 2D + Modality",
|
| 131 |
+
supports_2d=True,
|
| 132 |
+
needs_prompt=True,
|
| 133 |
+
category="foundation"
|
| 134 |
+
),
|
| 135 |
+
"sam_med3d": ModelConfig(
|
| 136 |
+
name="SAM-Med3D",
|
| 137 |
+
enabled=os.getenv("ENABLE_SAM_MED3D", "false").lower() == "true",
|
| 138 |
+
description="Native 3D SAM with 245+ classes and sliding window",
|
| 139 |
+
short_desc="3D Multi-class (245+)",
|
| 140 |
+
supports_3d=True,
|
| 141 |
+
needs_prompt=True,
|
| 142 |
+
category="foundation"
|
| 143 |
+
),
|
| 144 |
+
"medsam_3d": ModelConfig(
|
| 145 |
+
name="MedSAM-3D",
|
| 146 |
+
enabled=os.getenv("ENABLE_MEDSAM_3D", "false").lower() == "true",
|
| 147 |
+
description="3D MedSAM with self-sorting memory bank",
|
| 148 |
+
short_desc="3D Memory Bank",
|
| 149 |
+
supports_3d=True,
|
| 150 |
+
needs_prompt=True,
|
| 151 |
+
category="foundation"
|
| 152 |
+
),
|
| 153 |
+
# Specialized Models
|
| 154 |
+
"tractseg": ModelConfig(
|
| 155 |
+
name="TractSeg",
|
| 156 |
+
enabled=os.getenv("ENABLE_TRACTSEG", "true").lower() == "true",
|
| 157 |
+
description="White matter bundle segmentation from diffusion MRI (72 bundles)",
|
| 158 |
+
short_desc="72 WM Bundles",
|
| 159 |
+
supports_3d=True,
|
| 160 |
+
supports_dwi=True,
|
| 161 |
+
needs_prompt=False,
|
| 162 |
+
category="specialized"
|
| 163 |
+
),
|
| 164 |
+
"nnunet": ModelConfig(
|
| 165 |
+
name="nnU-Net",
|
| 166 |
+
enabled=os.getenv("ENABLE_NNUNET", "true").lower() == "true",
|
| 167 |
+
description="Self-configuring U-Net for any biomedical dataset",
|
| 168 |
+
short_desc="Auto-Configuring",
|
| 169 |
+
supports_2d=True,
|
| 170 |
+
supports_3d=True,
|
| 171 |
+
needs_prompt=False,
|
| 172 |
+
category="specialized"
|
| 173 |
+
),
|
| 174 |
}
|
| 175 |
|
| 176 |
+
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3, "XRAY": 3}
|
| 177 |
+
|
| 178 |
+
# =============================================================================
|
| 179 |
+
# SAMPLE DATA FUNCTIONS
|
| 180 |
+
# =============================================================================
|
| 181 |
+
|
| 182 |
+
def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
|
| 183 |
+
"""Load a sample image by ID, downloading if necessary."""
|
| 184 |
+
if sample_id not in SAMPLE_IMAGES:
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
sample = SAMPLE_IMAGES[sample_id]
|
| 188 |
+
img_path = SAMPLES_DIR / sample["filename"]
|
| 189 |
+
|
| 190 |
+
# Download if not cached
|
| 191 |
+
if not img_path.exists():
|
| 192 |
+
try:
|
| 193 |
+
logger.info(f"Downloading sample {sample_id} from {sample['url']}")
|
| 194 |
+
urllib.request.urlretrieve(sample["url"], img_path)
|
| 195 |
+
logger.info(f"Sample downloaded to {img_path}")
|
| 196 |
+
except Exception as e:
|
| 197 |
+
logger.error(f"Failed to download sample {sample_id}: {e}")
|
| 198 |
+
return None
|
| 199 |
+
|
| 200 |
+
img = Image.open(img_path)
|
| 201 |
+
img_array = np.array(img)
|
| 202 |
+
|
| 203 |
+
# Convert to grayscale
|
| 204 |
+
if len(img_array.shape) == 3:
|
| 205 |
+
img_array = np.array(Image.fromarray(img_array).convert('L'))
|
| 206 |
+
|
| 207 |
+
meta = {
|
| 208 |
+
"name": sample["name"],
|
| 209 |
+
"description": sample["description"],
|
| 210 |
+
"modality": sample["modality"],
|
| 211 |
+
"default_box": sample["default_box"],
|
| 212 |
+
"shape": img_array.shape
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
return img_array, meta
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_sample_image_path(sample_id: str) -> Optional[Path]:
|
| 219 |
+
"""Get path to sample image, downloading if needed."""
|
| 220 |
+
if sample_id not in SAMPLE_IMAGES:
|
| 221 |
+
return None
|
| 222 |
+
|
| 223 |
+
sample = SAMPLE_IMAGES[sample_id]
|
| 224 |
+
img_path = SAMPLES_DIR / sample["filename"]
|
| 225 |
+
|
| 226 |
+
if not img_path.exists():
|
| 227 |
+
try:
|
| 228 |
+
urllib.request.urlretrieve(sample["url"], img_path)
|
| 229 |
+
except Exception as e:
|
| 230 |
+
logger.error(f"Failed to download sample: {e}")
|
| 231 |
+
return None
|
| 232 |
+
|
| 233 |
+
return img_path
|
| 234 |
|
| 235 |
# =============================================================================
|
| 236 |
# UTILITY FUNCTIONS
|
|
|
|
| 266 |
|
| 267 |
def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
|
| 268 |
"""Create segmentation overlay."""
|
|
|
|
| 269 |
if image.dtype != np.uint8:
|
| 270 |
img_norm = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
|
| 271 |
else:
|
| 272 |
img_norm = image
|
| 273 |
|
|
|
|
| 274 |
if len(img_norm.shape) == 2:
|
| 275 |
img_rgb = np.stack([img_norm] * 3, axis=-1)
|
| 276 |
else:
|
| 277 |
img_rgb = img_norm
|
| 278 |
|
|
|
|
| 279 |
overlay = img_rgb.copy()
|
| 280 |
mask_bool = mask > 0 if mask.dtype == np.uint8 else mask.astype(bool)
|
| 281 |
|
|
|
|
| 284 |
|
| 285 |
return Image.fromarray(overlay.astype(np.uint8))
|
| 286 |
|
|
|
|
| 287 |
# =============================================================================
|
| 288 |
# MODEL INFERENCE FUNCTIONS
|
| 289 |
# =============================================================================
|
|
|
|
| 291 |
@spaces.GPU(duration=120)
|
| 292 |
def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
|
| 293 |
"""Run MedSAM2 on 3D volume."""
|
|
|
|
| 294 |
box = json.loads(box_json)
|
| 295 |
logger.info(f"MedSAM2: Processing 3D volume with box {box}")
|
| 296 |
|
| 297 |
+
# Mock implementation
|
| 298 |
+
D, H, W = 64, 256, 256
|
| 299 |
+
mask = np.zeros((D, H, W), dtype=np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
|
| 301 |
|
| 302 |
for z in range(D):
|
|
|
|
| 318 |
"""Run MCP-MedSAM on 2D image."""
|
| 319 |
logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
|
| 320 |
|
|
|
|
| 321 |
H, W = image.shape[:2]
|
| 322 |
mask = np.zeros((H, W), dtype=np.uint8)
|
| 323 |
|
| 324 |
x1, y1, x2, y2 = int(box["x1"]), int(box["y1"]), int(box["x2"]), int(box["y2"])
|
| 325 |
mask[y1:y2, x1:x2] = 1
|
| 326 |
|
|
|
|
| 327 |
from scipy import ndimage
|
| 328 |
mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
|
| 329 |
|
|
|
|
| 340 |
def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
|
| 341 |
"""Run SAM-Med3D."""
|
| 342 |
logger.info(f"SAM-Med3D: Processing with points {points}")
|
|
|
|
|
|
|
| 343 |
mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
|
| 344 |
|
| 345 |
return {
|
|
|
|
| 350 |
}
|
| 351 |
|
| 352 |
|
| 353 |
+
@spaces.GPU(duration=120)
|
| 354 |
+
def run_medsam_3d(volume: np.ndarray, box: Dict) -> Dict:
|
| 355 |
+
"""Run MedSAM-3D."""
|
| 356 |
+
logger.info(f"MedSAM-3D: Processing with box {box}")
|
| 357 |
+
mask = np.random.rand(*volume.shape[:3]) > 0.5
|
| 358 |
+
|
| 359 |
+
return {
|
| 360 |
+
"mask": mask.astype(np.uint8),
|
| 361 |
+
"mask_b64": compress_mask(mask.astype(np.uint8)),
|
| 362 |
+
"shape": list(volume.shape[:3]),
|
| 363 |
+
"method": "medsam_3d"
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@spaces.GPU(duration=180)
|
| 368 |
+
def run_tractseg(volume: np.ndarray) -> Dict:
|
| 369 |
+
"""Run TractSeg."""
|
| 370 |
+
logger.info("TractSeg: Processing DWI")
|
| 371 |
+
bundles = np.random.rand(*volume.shape[:3], 72) > 0.5
|
| 372 |
+
|
| 373 |
+
return {
|
| 374 |
+
"bundles": bundles.astype(np.uint8),
|
| 375 |
+
"mask_b64": compress_mask(bundles.astype(np.uint8)),
|
| 376 |
+
"shape": list(bundles.shape),
|
| 377 |
+
"method": "tractseg",
|
| 378 |
+
"num_bundles": 72
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@spaces.GPU(duration=120)
|
| 383 |
+
def run_nnunet(volume: np.ndarray, task: str = "Task001_BrainTumour") -> Dict:
|
| 384 |
+
"""Run nnU-Net."""
|
| 385 |
+
logger.info(f"nnU-Net: Processing task {task}")
|
| 386 |
+
|
| 387 |
+
if volume.ndim == 3:
|
| 388 |
+
seg = np.random.randint(0, 4, size=volume.shape, dtype=np.uint8)
|
| 389 |
+
else:
|
| 390 |
+
seg = np.random.randint(0, 4, size=volume.shape[:2], dtype=np.uint8)
|
| 391 |
+
|
| 392 |
+
return {
|
| 393 |
+
"segmentation": seg,
|
| 394 |
+
"mask_b64": compress_mask(seg),
|
| 395 |
+
"shape": list(seg.shape),
|
| 396 |
+
"method": "nnunet",
|
| 397 |
+
"task": task
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
# =============================================================================
|
| 401 |
+
# API ENDPOINTS
|
| 402 |
# =============================================================================
|
| 403 |
|
| 404 |
def api_health():
|
| 405 |
"""Health check endpoint."""
|
| 406 |
+
enabled = [k for k, v in MODELS_CONFIG.items() if v.enabled]
|
| 407 |
return {
|
| 408 |
"status": "healthy",
|
| 409 |
"models": enabled,
|
| 410 |
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 411 |
+
"version": "2.0.0",
|
| 412 |
+
"samples_available": list(SAMPLE_IMAGES.keys())
|
| 413 |
}
|
| 414 |
|
| 415 |
|
| 416 |
+
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
|
| 417 |
+
"""Gradio-compatible endpoint for HydroMorph app."""
|
| 418 |
+
try:
|
| 419 |
+
logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
|
| 420 |
+
|
| 421 |
+
if image_file is None:
|
| 422 |
+
return None, "Error: No image provided"
|
| 423 |
+
|
| 424 |
+
# Load image
|
| 425 |
+
if isinstance(image_file, dict):
|
| 426 |
+
file_path = image_file.get("path")
|
| 427 |
+
elif hasattr(image_file, 'name'):
|
| 428 |
+
file_path = image_file.name
|
| 429 |
+
else:
|
| 430 |
+
file_path = str(image_file)
|
| 431 |
+
|
| 432 |
+
img = Image.open(file_path).convert('L')
|
| 433 |
+
image = np.array(img)
|
| 434 |
+
|
| 435 |
+
# Mock segmentation
|
| 436 |
+
H, W = image.shape
|
| 437 |
+
mask = np.zeros((H, W), dtype=np.uint8)
|
| 438 |
+
cy, cx = H // 2, W // 2
|
| 439 |
+
|
| 440 |
+
for y in range(H):
|
| 441 |
+
for x in range(W):
|
| 442 |
+
if ((y - cy) / (H / 4)) ** 2 + ((x - cx) / (W / 4)) ** 2 <= 1:
|
| 443 |
+
mask[y, x] = 1
|
| 444 |
+
|
| 445 |
+
overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
|
| 446 |
+
temp_path = TEMP_DIR / f"result_{int(time.time())}.png"
|
| 447 |
+
overlay.save(temp_path)
|
| 448 |
+
|
| 449 |
+
return str(temp_path), f"Segmented {prompt} using {modality}"
|
| 450 |
+
|
| 451 |
+
except Exception as e:
|
| 452 |
+
logger.exception("process_with_status failed")
|
| 453 |
+
return None, f"Error: {str(e)}"
|
| 454 |
+
|
| 455 |
+
|
| 456 |
def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
|
| 457 |
+
"""Direct JSON API for 2D segmentation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
try:
|
| 459 |
box = json.loads(box_json)
|
| 460 |
|
|
|
|
| 461 |
if hasattr(image_file, 'name'):
|
| 462 |
img = Image.open(image_file.name).convert('L')
|
| 463 |
else:
|
|
|
|
| 465 |
|
| 466 |
image = np.array(img)
|
| 467 |
|
|
|
|
| 468 |
if model == "mcp_medsam":
|
| 469 |
result = run_mcp_medsam_2d(image, box, modality)
|
| 470 |
else:
|
| 471 |
return {"error": f"Model {model} not supported for 2D"}
|
| 472 |
|
|
|
|
| 473 |
overlay = overlay_mask_on_image(image, result["mask"])
|
| 474 |
overlay_b64 = image_to_base64(overlay)
|
| 475 |
|
|
|
|
| 488 |
|
| 489 |
|
| 490 |
def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
|
| 491 |
+
"""Direct JSON API for 3D segmentation."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
try:
|
|
|
|
| 493 |
if hasattr(volume_file, 'name'):
|
| 494 |
file_path = volume_file.name
|
| 495 |
else:
|
|
|
|
| 498 |
with open(file_path, 'rb') as f:
|
| 499 |
volume_bytes = f.read()
|
| 500 |
|
|
|
|
| 501 |
if model == "medsam2":
|
| 502 |
result = run_medsam2_3d(volume_bytes, box_json)
|
| 503 |
else:
|
|
|
|
| 515 |
return {"error": str(e)}
|
| 516 |
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
|
| 519 |
+
"""Compare multiple models on the same image."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
try:
|
| 521 |
models = json.loads(models_json)
|
| 522 |
box = json.loads(box_json)
|
| 523 |
|
|
|
|
| 524 |
if hasattr(image_file, 'name'):
|
| 525 |
img = Image.open(image_file.name).convert('L')
|
| 526 |
else:
|
| 527 |
img = Image.open(image_file).convert('L')
|
| 528 |
|
| 529 |
image = np.array(img)
|
|
|
|
| 530 |
results = {}
|
| 531 |
+
|
| 532 |
+
colors = {
|
| 533 |
+
"mcp_medsam": (0, 255, 0),
|
| 534 |
+
"medsam2": (255, 0, 0),
|
| 535 |
+
"sam_med3d": (0, 0, 255),
|
| 536 |
+
"medsam_3d": (255, 255, 0),
|
| 537 |
+
"nnunet": (255, 0, 255)
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
for model in models:
|
| 541 |
start = time.time()
|
| 542 |
try:
|
| 543 |
if model == "mcp_medsam":
|
| 544 |
result = run_mcp_medsam_2d(image, box, modality)
|
| 545 |
+
color = colors.get(model, (0, 255, 0))
|
| 546 |
+
overlay = overlay_mask_on_image(image, result["mask"], color=color)
|
| 547 |
+
elif model in ["medsam2", "sam_med3d", "medsam_3d", "nnunet"]:
|
| 548 |
+
result = run_mcp_medsam_2d(image, box, modality)
|
| 549 |
+
color = colors.get(model, (128, 128, 128))
|
| 550 |
+
overlay = overlay_mask_on_image(image, result["mask"], color=color)
|
| 551 |
else:
|
| 552 |
continue
|
| 553 |
|
|
|
|
| 566 |
except Exception as e:
|
| 567 |
return {"error": str(e)}
|
| 568 |
|
|
|
|
| 569 |
# =============================================================================
|
| 570 |
# GRADIO INTERFACE
|
| 571 |
# =============================================================================
|
| 572 |
|
| 573 |
def create_interface():
|
| 574 |
+
"""Create Gradio interface with sample images and all models."""
|
| 575 |
|
| 576 |
+
with gr.Blocks(
|
| 577 |
+
title="NeuroSeg Server - HydroMorph Backend",
|
| 578 |
+
theme=gr.themes.Soft(),
|
| 579 |
+
css="""
|
| 580 |
+
.sample-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
|
| 581 |
+
.model-checkbox { margin: 5px 0; }
|
| 582 |
+
"""
|
| 583 |
+
) as demo:
|
| 584 |
gr.Markdown("""
|
| 585 |
# 🧠 NeuroSeg Server
|
| 586 |
|
| 587 |
Backend API for HydroMorph React Native app (iOS, Android, Web).
|
| 588 |
|
| 589 |
+
**MCP Server**: `https://mmrech-medsam2-server.hf.space/gradio_api/mcp/sse`
|
| 590 |
+
|
| 591 |
+
**Models**: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
|
|
|
|
|
|
|
|
|
|
| 592 |
""")
|
| 593 |
|
| 594 |
+
# --- Try with Sample CT ---
|
| 595 |
+
with gr.Tab("📋 Try with Sample CT"):
|
| 596 |
+
gr.Markdown("Select a sample CT scan to test the models:")
|
|
|
|
| 597 |
|
| 598 |
+
sample_radio = gr.Radio(
|
| 599 |
+
choices=[(f"{v['name']}: {v['description'][:50]}...", k) for k, v in SAMPLE_IMAGES.items()],
|
| 600 |
+
value="nph_1",
|
| 601 |
+
label="Select Sample"
|
| 602 |
+
)
|
| 603 |
|
| 604 |
+
with gr.Row():
|
| 605 |
+
sample_preview = gr.Image(label="Selected Sample", type="pil")
|
| 606 |
+
sample_info = gr.JSON(label="Sample Info")
|
| 607 |
|
| 608 |
+
def load_sample_preview(sample_id):
|
| 609 |
+
result = load_sample_image(sample_id)
|
| 610 |
+
if result is None:
|
| 611 |
+
return None, {}
|
| 612 |
+
img_array, meta = result
|
| 613 |
+
return Image.fromarray(img_array), meta
|
| 614 |
|
| 615 |
+
sample_radio.change(
|
| 616 |
+
fn=load_sample_preview,
|
| 617 |
+
inputs=[sample_radio],
|
| 618 |
+
outputs=[sample_preview, sample_info]
|
| 619 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
|
| 621 |
+
# Load initial sample
|
| 622 |
+
demo.load(
|
| 623 |
+
fn=lambda: load_sample_preview("nph_1"),
|
| 624 |
+
outputs=[sample_preview, sample_info]
|
|
|
|
| 625 |
)
|
| 626 |
+
|
| 627 |
+
gr.Markdown("### Use this sample in:")
|
| 628 |
+
with gr.Row():
|
| 629 |
+
use_in_single = gr.Button("🎯 Single Model", variant="secondary")
|
| 630 |
+
use_in_compare = gr.Button("🔬 Model Comparison", variant="secondary")
|
| 631 |
|
| 632 |
+
# --- Single Model ---
|
| 633 |
+
with gr.Tab("🎯 Single Model"):
|
| 634 |
with gr.Row():
|
| 635 |
+
with gr.Column(scale=1):
|
| 636 |
+
# Input source
|
| 637 |
+
input_source = gr.Radio(
|
| 638 |
+
choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
|
| 639 |
+
value="sample",
|
| 640 |
+
label="Input Source"
|
| 641 |
)
|
| 642 |
+
|
| 643 |
+
single_sample = gr.Dropdown(
|
| 644 |
+
choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
|
| 645 |
+
value="nph_1",
|
| 646 |
+
label="Sample",
|
| 647 |
+
visible=True
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
single_upload = gr.Image(
|
| 651 |
+
label="Upload",
|
| 652 |
+
type="filepath",
|
| 653 |
+
visible=False
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
def toggle_input(source):
|
| 657 |
+
return {
|
| 658 |
+
single_sample: gr.update(visible=source == "sample"),
|
| 659 |
+
single_upload: gr.update(visible=source == "upload")
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
input_source.change(fn=toggle_input, inputs=[input_source], outputs=[single_sample, single_upload])
|
| 663 |
+
|
| 664 |
+
# Model selection
|
| 665 |
+
enabled_models = [(v.name, k) for k, v in MODELS_CONFIG.items() if v.enabled]
|
| 666 |
+
single_model = gr.Dropdown(
|
| 667 |
+
choices=enabled_models,
|
| 668 |
+
value=enabled_models[0][1] if enabled_models else None,
|
| 669 |
+
label="Model"
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
# Dynamic inputs based on model
|
| 673 |
+
single_box = gr.Textbox(
|
| 674 |
+
label="Bounding Box (JSON)",
|
| 675 |
+
value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"]),
|
| 676 |
+
visible=True
|
| 677 |
)
|
| 678 |
+
|
| 679 |
+
single_modality = gr.Dropdown(
|
| 680 |
label="Modality",
|
| 681 |
choices=list(MODALITY_MAP.keys()),
|
| 682 |
+
value="CT",
|
| 683 |
+
visible=True
|
| 684 |
)
|
| 685 |
+
|
| 686 |
+
single_prompt_only = gr.Checkbox(
|
| 687 |
+
label="Show only prompt-based models",
|
| 688 |
+
value=False
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
single_run = gr.Button("🚀 Run Model", variant="primary")
|
| 692 |
|
| 693 |
+
with gr.Column(scale=2):
|
| 694 |
+
single_output = gr.JSON(label="Result")
|
| 695 |
+
single_overlay = gr.Image(label="Segmentation Overlay")
|
| 696 |
|
| 697 |
+
def run_single(source, sample, upload, model, box, modality):
|
| 698 |
+
if source == "sample":
|
| 699 |
+
img_path = get_sample_image_path(sample)
|
| 700 |
+
else:
|
| 701 |
+
img_path = upload
|
| 702 |
+
|
| 703 |
+
if img_path is None:
|
| 704 |
+
return {"error": "No image provided"}, None
|
| 705 |
+
|
| 706 |
+
result = api_segment_2d(img_path, box, model, modality)
|
| 707 |
+
|
| 708 |
if "error" in result:
|
| 709 |
return result, None
|
| 710 |
|
| 711 |
+
# Generate overlay
|
| 712 |
+
img = Image.open(img_path).convert('L')
|
| 713 |
mask = decompress_mask(result["mask_b64"])
|
|
|
|
| 714 |
overlay = overlay_mask_on_image(np.array(img), mask)
|
| 715 |
|
| 716 |
return result, overlay
|
| 717 |
|
| 718 |
+
single_run.click(
|
| 719 |
+
fn=run_single,
|
| 720 |
+
inputs=[input_source, single_sample, single_upload, single_model, single_box, single_modality],
|
| 721 |
+
outputs=[single_output, single_overlay]
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
# --- Model Comparison ---
|
| 725 |
+
with gr.Tab("🔬 Model Comparison"):
|
| 726 |
+
with gr.Row():
|
| 727 |
+
with gr.Column(scale=1):
|
| 728 |
+
comp_input_source = gr.Radio(
|
| 729 |
+
choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
|
| 730 |
+
value="sample",
|
| 731 |
+
label="Input Source"
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
comp_sample = gr.Dropdown(
|
| 735 |
+
choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
|
| 736 |
+
value="nph_1",
|
| 737 |
+
label="Sample",
|
| 738 |
+
visible=True
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
comp_upload = gr.Image(
|
| 742 |
+
label="Upload",
|
| 743 |
+
type="filepath",
|
| 744 |
+
visible=False
|
| 745 |
+
)
|
| 746 |
+
|
| 747 |
+
comp_input_source.change(
|
| 748 |
+
fn=lambda x: {comp_sample: gr.update(visible=x == "sample"), comp_upload: gr.update(visible=x == "upload")},
|
| 749 |
+
inputs=[comp_input_source],
|
| 750 |
+
outputs=[comp_sample, comp_upload]
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
comp_box = gr.Textbox(
|
| 754 |
+
label="Bounding Box (JSON)",
|
| 755 |
+
value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"])
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
comp_modality = gr.Dropdown(
|
| 759 |
+
label="Modality",
|
| 760 |
+
choices=list(MODALITY_MAP.keys()),
|
| 761 |
+
value="CT"
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
# Model selection with categories
|
| 765 |
+
gr.Markdown("### Select Models to Compare")
|
| 766 |
+
|
| 767 |
+
comp_prompt_only = gr.Checkbox(
|
| 768 |
+
label="Prompt-based models only",
|
| 769 |
+
value=False,
|
| 770 |
+
info="Filter to models that accept prompts"
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# Foundation models
|
| 774 |
+
gr.Markdown("**Foundation Models**")
|
| 775 |
+
comp_medsam2 = gr.Checkbox(label="MedSAM2 (3D Bi-directional)", value=True)
|
| 776 |
+
comp_mcp = gr.Checkbox(label="MCP-MedSAM (Fast 2D)", value=True)
|
| 777 |
+
comp_sam3d = gr.Checkbox(label="SAM-Med3D (245+ classes)", value=False)
|
| 778 |
+
comp_medsam3d = gr.Checkbox(label="MedSAM-3D (Memory Bank)", value=False)
|
| 779 |
+
|
| 780 |
+
# Specialized models
|
| 781 |
+
gr.Markdown("**Specialized Models**")
|
| 782 |
+
comp_tractseg = gr.Checkbox(label="TractSeg (72 bundles)", value=False)
|
| 783 |
+
comp_nnunet = gr.Checkbox(label="nnU-Net (Auto-configuring)", value=False)
|
| 784 |
+
|
| 785 |
+
comp_run = gr.Button("🚀 Run Comparison", variant="primary")
|
| 786 |
+
|
| 787 |
+
with gr.Column(scale=2):
|
| 788 |
+
comp_output = gr.JSON(label="Comparison Results")
|
| 789 |
+
comp_gallery = gr.Gallery(label="Model Overlays")
|
| 790 |
+
|
| 791 |
+
def run_comparison(source, sample, upload, box, modality, prompt_only, *model_flags):
|
| 792 |
+
models = []
|
| 793 |
+
model_names = ["medsam2", "mcp_medsam", "sam_med3d", "medsam_3d", "tractseg", "nnunet"]
|
| 794 |
+
|
| 795 |
+
for name, enabled in zip(model_names, model_flags):
|
| 796 |
+
if enabled:
|
| 797 |
+
# Skip non-prompt models if prompt_only is checked
|
| 798 |
+
if prompt_only and not MODELS_CONFIG[name].needs_prompt:
|
| 799 |
+
continue
|
| 800 |
+
models.append(name)
|
| 801 |
+
|
| 802 |
+
if not models:
|
| 803 |
+
return {"error": "No models selected"}, []
|
| 804 |
+
|
| 805 |
+
if source == "sample":
|
| 806 |
+
img_path = get_sample_image_path(sample)
|
| 807 |
+
else:
|
| 808 |
+
img_path = upload
|
| 809 |
+
|
| 810 |
+
if img_path is None:
|
| 811 |
+
return {"error": "No image provided"}, []
|
| 812 |
+
|
| 813 |
+
result = api_compare_models(img_path, box, json.dumps(models), modality)
|
| 814 |
+
|
| 815 |
+
if "error" in result:
|
| 816 |
+
return result, []
|
| 817 |
+
|
| 818 |
+
# Extract gallery images
|
| 819 |
+
gallery = []
|
| 820 |
+
for model, data in result.get("results", {}).items():
|
| 821 |
+
if data.get("success") and "overlay_b64" in data:
|
| 822 |
+
img = base64_to_image(data["overlay_b64"])
|
| 823 |
+
gallery.append((img, f"{model} ({data.get('inference_time', 0)}s)"))
|
| 824 |
+
|
| 825 |
+
return result, gallery
|
| 826 |
+
|
| 827 |
+
comp_run.click(
|
| 828 |
+
fn=run_comparison,
|
| 829 |
+
inputs=[
|
| 830 |
+
comp_input_source, comp_sample, comp_upload,
|
| 831 |
+
comp_box, comp_modality, comp_prompt_only,
|
| 832 |
+
comp_medsam2, comp_mcp, comp_sam3d, comp_medsam3d,
|
| 833 |
+
comp_tractseg, comp_nnunet
|
| 834 |
+
],
|
| 835 |
+
outputs=[comp_output, comp_gallery]
|
| 836 |
)
|
| 837 |
|
| 838 |
# --- 3D Segmentation ---
|
|
|
|
| 846 |
)
|
| 847 |
seg3d_model = gr.Dropdown(
|
| 848 |
label="Model",
|
| 849 |
+
choices=[(v.name, k) for k, v in MODELS_CONFIG.items() if v.supports_3d and v.enabled],
|
| 850 |
value="medsam2"
|
| 851 |
)
|
| 852 |
+
seg3d_run = gr.Button("Segment", variant="primary")
|
| 853 |
|
| 854 |
with gr.Column():
|
| 855 |
seg3d_output = gr.JSON(label="Result")
|
| 856 |
|
| 857 |
+
seg3d_run.click(
|
| 858 |
fn=api_segment_3d,
|
| 859 |
inputs=[seg3d_volume, seg3d_box, seg3d_model],
|
| 860 |
+
outputs=seg3d_output
|
|
|
|
| 861 |
)
|
| 862 |
|
| 863 |
+
# --- Mobile App Endpoint ---
|
| 864 |
+
with gr.Tab("📱 Mobile App"):
|
| 865 |
+
gr.Markdown("""
|
| 866 |
+
This endpoint is used by the HydroMorph mobile app.
|
| 867 |
+
|
| 868 |
+
**Endpoint:** `POST /gradio_api/call/process_with_status`
|
| 869 |
+
""")
|
| 870 |
+
|
| 871 |
with gr.Row():
|
| 872 |
with gr.Column():
|
| 873 |
+
mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
|
| 874 |
+
mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
|
| 875 |
+
mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
|
| 876 |
+
mobile_window = gr.Dropdown(
|
| 877 |
+
label="Window",
|
| 878 |
+
choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
|
| 879 |
+
value="Brain (Grey Matter)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
)
|
| 881 |
+
mobile_btn = gr.Button("Run Segmentation", variant="primary")
|
| 882 |
|
| 883 |
with gr.Column():
|
| 884 |
+
mobile_result_img = gr.Image(label="Result")
|
| 885 |
+
mobile_status = gr.Textbox(label="Status")
|
|
|
|
|
|
|
| 886 |
|
| 887 |
+
mobile_btn.click(
|
| 888 |
+
fn=process_with_status,
|
| 889 |
+
inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
|
| 890 |
+
outputs=[mobile_result_img, mobile_status],
|
| 891 |
+
api_name="process_with_status"
|
| 892 |
)
|
| 893 |
|
| 894 |
+
# --- Status ---
|
| 895 |
with gr.Tab("⚙️ Status"):
|
| 896 |
health_btn = gr.Button("Check Health")
|
| 897 |
health_output = gr.JSON(label="System Status")
|
| 898 |
+
|
| 899 |
+
# Model status table
|
| 900 |
+
gr.Markdown("### Model Status")
|
| 901 |
+
model_status_data = [
|
| 902 |
+
[v.name, "✅ Enabled" if v.enabled else "❌ Disabled", v.category, "Yes" if v.needs_prompt else "No"]
|
| 903 |
+
for k, v in MODELS_CONFIG.items()
|
| 904 |
+
]
|
| 905 |
+
|
| 906 |
+
gr.Dataframe(
|
| 907 |
+
headers=["Model", "Status", "Category", "Needs Prompt"],
|
| 908 |
+
value=model_status_data
|
| 909 |
+
)
|
| 910 |
+
|
| 911 |
health_btn.click(fn=api_health, outputs=health_output, api_name="health")
|
| 912 |
|
| 913 |
return demo
|
|
|
|
| 918 |
# =============================================================================
|
| 919 |
|
| 920 |
if __name__ == "__main__":
|
| 921 |
+
enabled = [k for k, v in MODELS_CONFIG.items() if v.enabled]
|
| 922 |
+
logger.info(f"Starting NeuroSeg Server with {len(enabled)} models: {enabled}")
|
| 923 |
+
logger.info(f"Samples configured: {list(SAMPLE_IMAGES.keys())}")
|
| 924 |
|
| 925 |
demo = create_interface()
|
| 926 |
+
|
| 927 |
+
# Launch with MCP server support
|
| 928 |
demo.launch(
|
| 929 |
server_name="0.0.0.0",
|
| 930 |
server_port=7860,
|
| 931 |
share=False,
|
| 932 |
+
show_api=True,
|
| 933 |
+
quiet=False,
|
| 934 |
+
mcp_server=True # Enable MCP server
|
| 935 |
)
|