mmrech committed on
Commit
af2b1cc
·
unverified ·
1 Parent(s): 5ebd54d

Add complete UI with sample images and MCP server

Browse files

- Add SAMPLE_IMAGES config with 3 NPH CT scans
- Add load_sample_image() and get_sample_image_path() functions
- Add ModelConfig dataclass for model metadata
- Add all 6 models: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
- Add 'Try with Sample CT' tab for quick testing
- Add 'Single Model' tab with sample/upload toggle
- Add 'Model Comparison' tab with category grouping and prompt-only filter
- Add run_medsam_3d(), run_tractseg(), run_nnunet() functions
- Enable mcp_server=True in demo.launch()
- Add colors for each model in comparison view
- Add Model Status table in Status tab

Files changed (2) hide show
  1. __pycache__/app.cpython-313.pyc +0 -0
  2. app.py +563 -230
__pycache__/app.cpython-313.pyc CHANGED
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
 
app.py CHANGED
@@ -12,6 +12,9 @@ ENDPOINTS FOR MOBILE APP:
12
  - POST /api/segment_3d — Direct JSON API for 3D volumes
13
  - GET /api/health — Health check
14
 
 
 
 
15
  MODELS SUPPORTED:
16
  - MedSAM2: 3D volume with bi-directional propagation
17
  - MCP-MedSAM: Fast 2D with modality/content prompts
@@ -31,6 +34,7 @@ import os
31
  import tempfile
32
  import base64
33
  import time
 
34
  from typing import Optional, Tuple, List, Dict, Any
35
  from dataclasses import dataclass, field
36
  from pathlib import Path
@@ -57,18 +61,176 @@ CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
57
  CHECKPOINT_DIR.mkdir(exist_ok=True)
58
  TEMP_DIR = SCRIPT_DIR / "temp"
59
  TEMP_DIR.mkdir(exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Model configs
62
- MODELS = {
63
- "medsam2": {"name": "MedSAM2", "enabled": True, "supports_3d": True, "supports_2d": False},
64
- "mcp_medsam": {"name": "MCP-MedSAM", "enabled": True, "supports_3d": False, "supports_2d": True},
65
- "sam_med3d": {"name": "SAM-Med3D", "enabled": False, "supports_3d": True, "supports_2d": False},
66
- "medsam_3d": {"name": "MedSAM-3D", "enabled": False, "supports_3d": True, "supports_2d": False},
67
- "tractseg": {"name": "TractSeg", "enabled": True, "supports_3d": True, "supports_2d": False},
68
- "nnunet": {"name": "nnU-Net", "enabled": True, "supports_3d": True, "supports_2d": True},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  }
70
 
71
- MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # =============================================================================
74
  # UTILITY FUNCTIONS
@@ -104,19 +266,16 @@ def base64_to_image(b64: str) -> Image.Image:
104
 
105
  def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
106
  """Create segmentation overlay."""
107
- # Normalize image to 0-255
108
  if image.dtype != np.uint8:
109
  img_norm = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
110
  else:
111
  img_norm = image
112
 
113
- # Convert to RGB
114
  if len(img_norm.shape) == 2:
115
  img_rgb = np.stack([img_norm] * 3, axis=-1)
116
  else:
117
  img_rgb = img_norm
118
 
119
- # Create overlay
120
  overlay = img_rgb.copy()
121
  mask_bool = mask > 0 if mask.dtype == np.uint8 else mask.astype(bool)
122
 
@@ -125,7 +284,6 @@ def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int,
125
 
126
  return Image.fromarray(overlay.astype(np.uint8))
127
 
128
-
129
  # =============================================================================
130
  # MODEL INFERENCE FUNCTIONS
131
  # =============================================================================
@@ -133,17 +291,12 @@ def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int,
133
  @spaces.GPU(duration=120)
134
  def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
135
  """Run MedSAM2 on 3D volume."""
136
- # Mock implementation - replace with actual model
137
  box = json.loads(box_json)
138
  logger.info(f"MedSAM2: Processing 3D volume with box {box}")
139
 
140
- # Load volume (mock)
141
- vol_buf = io.BytesIO(volume_bytes)
142
- volume = np.load(vol_buf) if vol_buf.getvalue()[:1] != b'\x1f' else np.random.rand(64, 256, 256)
143
-
144
- # Generate mock mask (ellipsoid)
145
- mask = np.zeros_like(volume, dtype=np.uint8)
146
- D, H, W = volume.shape
147
  cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
148
 
149
  for z in range(D):
@@ -165,14 +318,12 @@ def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dic
165
  """Run MCP-MedSAM on 2D image."""
166
  logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
167
 
168
- # Mock segmentation
169
  H, W = image.shape[:2]
170
  mask = np.zeros((H, W), dtype=np.uint8)
171
 
172
  x1, y1, x2, y2 = int(box["x1"]), int(box["y1"]), int(box["x2"]), int(box["y2"])
173
  mask[y1:y2, x1:x2] = 1
174
 
175
- # Smooth edges (mock)
176
  from scipy import ndimage
177
  mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
178
 
@@ -189,8 +340,6 @@ def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dic
189
  def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
190
  """Run SAM-Med3D."""
191
  logger.info(f"SAM-Med3D: Processing with points {points}")
192
-
193
- # Mock multi-class segmentation
194
  mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
195
 
196
  return {
@@ -201,38 +350,114 @@ def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]
201
  }
202
 
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  # =============================================================================
205
- # API ENDPOINTS FOR MOBILE APP
206
  # =============================================================================
207
 
208
  def api_health():
209
  """Health check endpoint."""
210
- enabled = [k for k, v in MODELS.items() if v["enabled"]]
211
  return {
212
  "status": "healthy",
213
  "models": enabled,
214
  "device": "cuda" if torch.cuda.is_available() else "cpu",
215
- "version": "2.0.0"
 
216
  }
217
 
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
220
- """
221
- Direct JSON API for 2D segmentation (no Gradio protocol).
222
-
223
- Args:
224
- image_file: Gradio File object
225
- box_json: JSON string {"x1": int, "y1": int, "x2": int, "y2": int}
226
- model: Model to use (default: mcp_medsam)
227
- modality: Imaging modality
228
-
229
- Returns:
230
- JSON with mask_b64, overlay image, and metadata
231
- """
232
  try:
233
  box = json.loads(box_json)
234
 
235
- # Load image
236
  if hasattr(image_file, 'name'):
237
  img = Image.open(image_file.name).convert('L')
238
  else:
@@ -240,13 +465,11 @@ def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modalit
240
 
241
  image = np.array(img)
242
 
243
- # Run model
244
  if model == "mcp_medsam":
245
  result = run_mcp_medsam_2d(image, box, modality)
246
  else:
247
  return {"error": f"Model {model} not supported for 2D"}
248
 
249
- # Generate overlay
250
  overlay = overlay_mask_on_image(image, result["mask"])
251
  overlay_b64 = image_to_base64(overlay)
252
 
@@ -265,19 +488,8 @@ def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modalit
265
 
266
 
267
  def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
268
- """
269
- Direct JSON API for 3D segmentation.
270
-
271
- Args:
272
- volume_file: .npy or .nii.gz file
273
- box_json: JSON with box coordinates and slice_idx
274
- model: Model to use (default: medsam2)
275
-
276
- Returns:
277
- JSON with mask_b64 and metadata
278
- """
279
  try:
280
- # Read file
281
  if hasattr(volume_file, 'name'):
282
  file_path = volume_file.name
283
  else:
@@ -286,7 +498,6 @@ def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
286
  with open(file_path, 'rb') as f:
287
  volume_bytes = f.read()
288
 
289
- # Run model
290
  if model == "medsam2":
291
  result = run_medsam2_3d(volume_bytes, box_json)
292
  else:
@@ -304,99 +515,39 @@ def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
304
  return {"error": str(e)}
305
 
306
 
307
- def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
308
- """
309
- Gradio-compatible endpoint for HydroMorph app.
310
-
311
- Expected by GradioClient.js segmentImage():
312
- - Input: [file_ref, prompt, modality, windowType]
313
- - Output: [segmentation_image, status_text]
314
-
315
- Returns:
316
- [image, status] tuple for Gradio
317
- """
318
- try:
319
- logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
320
-
321
- # Load image from file reference
322
- if image_file is None:
323
- return None, "Error: No image provided"
324
-
325
- # Handle Gradio file reference format
326
- if isinstance(image_file, dict):
327
- file_path = image_file.get("path")
328
- elif hasattr(image_file, 'name'):
329
- file_path = image_file.name
330
- else:
331
- file_path = str(image_file)
332
-
333
- # Load and process
334
- img = Image.open(file_path).convert('L')
335
- image = np.array(img)
336
-
337
- # Mock segmentation based on prompt
338
- H, W = image.shape
339
- mask = np.zeros((H, W), dtype=np.uint8)
340
-
341
- # Create elliptical mask in center (mock ventricles)
342
- cy, cx = H // 2, W // 2
343
- for y in range(H):
344
- for x in range(W):
345
- if ((y - cy) / (H / 4)) ** 2 + ((x - cx) / (W / 4)) ** 2 <= 1:
346
- mask[y, x] = 1
347
-
348
- # Generate overlay
349
- overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
350
-
351
- # Save to temp file for Gradio to serve
352
- temp_path = TEMP_DIR / f"result_{int(time.time())}.png"
353
- overlay.save(temp_path)
354
-
355
- status = f"Segmented {prompt} using {modality} window"
356
-
357
- return str(temp_path), status
358
-
359
- except Exception as e:
360
- logger.exception("process_with_status failed")
361
- return None, f"Error: {str(e)}"
362
-
363
-
364
  def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
365
- """
366
- Compare multiple models on the same image.
367
-
368
- Args:
369
- image_file: Input image
370
- box_json: Bounding box
371
- models_json: JSON array of model names ["mcp_medsam", "medsam2"]
372
- modality: Imaging modality
373
-
374
- Returns:
375
- JSON with results from each model
376
- """
377
  try:
378
  models = json.loads(models_json)
379
  box = json.loads(box_json)
380
 
381
- # Load image
382
  if hasattr(image_file, 'name'):
383
  img = Image.open(image_file.name).convert('L')
384
  else:
385
  img = Image.open(image_file).convert('L')
386
 
387
  image = np.array(img)
388
-
389
  results = {}
 
 
 
 
 
 
 
 
 
390
  for model in models:
391
  start = time.time()
392
  try:
393
  if model == "mcp_medsam":
394
  result = run_mcp_medsam_2d(image, box, modality)
395
- overlay = overlay_mask_on_image(image, result["mask"], color=(0, 255, 0))
396
- elif model == "medsam2" and image_file:
397
- # 2D slice mode
398
- result = run_mcp_medsam_2d(image, box, modality) # Fallback for demo
399
- overlay = overlay_mask_on_image(image, result["mask"], color=(255, 0, 0))
 
400
  else:
401
  continue
402
 
@@ -415,105 +566,273 @@ def api_compare_models(image_file, box_json: str, models_json: str, modality: st
415
  except Exception as e:
416
  return {"error": str(e)}
417
 
418
-
419
  # =============================================================================
420
  # GRADIO INTERFACE
421
  # =============================================================================
422
 
423
  def create_interface():
424
- """Create Gradio interface with HydroMorph-compatible endpoints."""
425
 
426
- with gr.Blocks(title="NeuroSeg Server - HydroMorph Backend", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
427
  gr.Markdown("""
428
  # 🧠 NeuroSeg Server
429
 
430
  Backend API for HydroMorph React Native app (iOS, Android, Web).
431
 
432
- **Mobile-Compatible Endpoints:**
433
- - `POST /gradio_api/upload` - Upload PNG slice
434
- - `POST /gradio_api/call/process_with_status` - Segment with status
435
- - `POST /api/segment_2d` - Direct JSON API for 2D
436
- - `POST /api/segment_3d` - Direct JSON API for 3D
437
- - `GET /api/health` - Health check
438
  """)
439
 
440
- # --- HydroMorph Mobile Endpoint ---
441
- with gr.Tab("📱 Mobile App Endpoint"):
442
- gr.Markdown("""
443
- This endpoint is used by the HydroMorph mobile app.
444
 
445
- **Endpoint:** `POST /gradio_api/call/process_with_status`
 
 
 
 
446
 
447
- **Input Format:** `[file_ref, prompt, modality, window_type]`
 
 
448
 
449
- **Output Format:** `[segmentation_image, status_text]`
450
- """)
 
 
 
 
451
 
452
- with gr.Row():
453
- with gr.Column():
454
- mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
455
- mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
456
- mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
457
- mobile_window = gr.Dropdown(
458
- label="Window",
459
- choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
460
- value="Brain (Grey Matter)"
461
- )
462
- mobile_btn = gr.Button("Run Segmentation", variant="primary")
463
-
464
- with gr.Column():
465
- mobile_result_img = gr.Image(label="Segmentation Result")
466
- mobile_status = gr.Textbox(label="Status")
467
 
468
- mobile_btn.click(
469
- fn=process_with_status,
470
- inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
471
- outputs=[mobile_result_img, mobile_status],
472
- api_name="process_with_status"
473
  )
 
 
 
 
 
474
 
475
- # --- 2D Segmentation ---
476
- with gr.Tab("🎯 2D Segmentation"):
477
  with gr.Row():
478
- with gr.Column():
479
- seg2d_image = gr.Image(label="Image", type="filepath")
480
- seg2d_box = gr.Textbox(
481
- label="Bounding Box (JSON)",
482
- value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
 
483
  )
484
- seg2d_model = gr.Dropdown(
485
- label="Model",
486
- choices=["mcp_medsam"],
487
- value="mcp_medsam"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  )
489
- seg2d_modality = gr.Dropdown(
 
490
  label="Modality",
491
  choices=list(MODALITY_MAP.keys()),
492
- value="CT"
 
493
  )
494
- seg2d_btn = gr.Button("Segment", variant="primary")
 
 
 
 
 
 
495
 
496
- with gr.Column():
497
- seg2d_output = gr.JSON(label="Result")
498
- seg2d_overlay = gr.Image(label="Overlay")
499
 
500
- def segment_2d_with_overlay(image, box, model, modality):
501
- result = api_segment_2d(image, box, model, modality)
 
 
 
 
 
 
 
 
 
502
  if "error" in result:
503
  return result, None
504
 
505
- # Decompress and display overlay
 
506
  mask = decompress_mask(result["mask_b64"])
507
- img = Image.open(image.name if hasattr(image, 'name') else image).convert('L')
508
  overlay = overlay_mask_on_image(np.array(img), mask)
509
 
510
  return result, overlay
511
 
512
- seg2d_btn.click(
513
- fn=segment_2d_with_overlay,
514
- inputs=[seg2d_image, seg2d_box, seg2d_model, seg2d_modality],
515
- outputs=[seg2d_output, seg2d_overlay],
516
- api_name="segment_2d"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  )
518
 
519
  # --- 3D Segmentation ---
@@ -527,59 +846,68 @@ def create_interface():
527
  )
528
  seg3d_model = gr.Dropdown(
529
  label="Model",
530
- choices=["medsam2"],
531
  value="medsam2"
532
  )
533
- seg3d_btn = gr.Button("Segment", variant="primary")
534
 
535
  with gr.Column():
536
  seg3d_output = gr.JSON(label="Result")
537
 
538
- seg3d_btn.click(
539
  fn=api_segment_3d,
540
  inputs=[seg3d_volume, seg3d_box, seg3d_model],
541
- outputs=seg3d_output,
542
- api_name="segment_3d"
543
  )
544
 
545
- # --- Model Comparison ---
546
- with gr.Tab("🔬 Compare Models"):
 
 
 
 
 
 
547
  with gr.Row():
548
  with gr.Column():
549
- comp_image = gr.Image(label="Image", type="filepath")
550
- comp_box = gr.Textbox(
551
- label="Box (JSON)",
552
- value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
553
- )
554
- comp_models = gr.CheckboxGroup(
555
- label="Models to Compare",
556
- choices=["mcp_medsam", "medsam2"],
557
- value=["mcp_medsam"]
558
- )
559
- comp_modality = gr.Dropdown(
560
- label="Modality",
561
- choices=list(MODALITY_MAP.keys()),
562
- value="CT"
563
  )
564
- comp_btn = gr.Button("Run Comparison", variant="primary")
565
 
566
  with gr.Column():
567
- comp_output = gr.JSON(label="Comparison Results")
568
-
569
- def run_comparison(image, box, models, modality):
570
- return api_compare_models(image, box, json.dumps(models), modality)
571
 
572
- comp_btn.click(
573
- fn=run_comparison,
574
- inputs=[comp_image, comp_box, comp_models, comp_modality],
575
- outputs=comp_output,
576
- api_name="compare_models"
577
  )
578
 
579
- # --- Health Check ---
580
  with gr.Tab("⚙️ Status"):
581
  health_btn = gr.Button("Check Health")
582
  health_output = gr.JSON(label="System Status")
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  health_btn.click(fn=api_health, outputs=health_output, api_name="health")
584
 
585
  return demo
@@ -590,13 +918,18 @@ def create_interface():
590
  # =============================================================================
591
 
592
  if __name__ == "__main__":
593
- logger.info("Starting NeuroSeg Server for HydroMorph")
594
- logger.info(f"Enabled models: {[k for k, v in MODELS.items() if v['enabled']]}")
 
595
 
596
  demo = create_interface()
 
 
597
  demo.launch(
598
  server_name="0.0.0.0",
599
  server_port=7860,
600
  share=False,
601
- show_api=True
 
 
602
  )
 
12
  - POST /api/segment_3d — Direct JSON API for 3D volumes
13
  - GET /api/health — Health check
14
 
15
+ MCP SERVER:
16
+ - All models exposed as MCP tools at /gradio_api/mcp/sse
17
+
18
  MODELS SUPPORTED:
19
  - MedSAM2: 3D volume with bi-directional propagation
20
  - MCP-MedSAM: Fast 2D with modality/content prompts
 
34
  import tempfile
35
  import base64
36
  import time
37
+ import urllib.request
38
  from typing import Optional, Tuple, List, Dict, Any
39
  from dataclasses import dataclass, field
40
  from pathlib import Path
 
61
  CHECKPOINT_DIR.mkdir(exist_ok=True)
62
  TEMP_DIR = SCRIPT_DIR / "temp"
63
  TEMP_DIR.mkdir(exist_ok=True)
64
+ SAMPLES_DIR = SCRIPT_DIR / "samples"
65
+ SAMPLES_DIR.mkdir(exist_ok=True)
66
+
67
+ # =============================================================================
68
+ # SAMPLE DATA CONFIGURATION
69
+ # =============================================================================
70
+
71
# Sample NPH (normal pressure hydrocephalus) CT images hosted on the Hugging
# Face Hub. Each entry carries a download URL, display metadata, and a
# pixel-space bounding-box prompt (default_box) pre-tuned for that image.
# Files are fetched on demand and cached under SAMPLES_DIR (see
# load_sample_image / get_sample_image_path).
SAMPLE_IMAGES = {
    "nph_1": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36.png",
        "name": "NPH Case 1 - Coronal",
        "description": "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
        "modality": "CT",
        "default_box": {"x1": 450, "y1": 350, "x2": 750, "y2": 700},
        "filename": "normal-pressure-hydrocephalus-36.png"
    },
    "nph_2": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-2.png",
        "name": "NPH Case 2 - Coronal",
        "description": "NPH showing ventricular enlargement and transependymal changes",
        "modality": "CT",
        "default_box": {"x1": 400, "y1": 300, "x2": 700, "y2": 650},
        "filename": "normal-pressure-hydrocephalus-36-2.png"
    },
    "nph_3": {
        "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-3.png",
        "name": "NPH Case 3 - Axial",
        "description": "Axial view showing enlarged lateral ventricles",
        "modality": "CT",
        "default_box": {"x1": 420, "y1": 380, "x2": 680, "y2": 620},
        "filename": "normal-pressure-hydrocephalus-36-3.png"
    }
}
97
+
98
+ # =============================================================================
99
+ # MODEL CONFIGURATION
100
+ # =============================================================================
101
 
102
@dataclass
class ModelConfig:
    """Model configuration with capabilities.

    One instance per entry in MODELS_CONFIG; drives which tabs, dropdowns,
    and comparison options each model appears in.
    """
    name: str          # Human-readable model name shown in the UI
    enabled: bool      # Toggled per deployment via an ENABLE_* env var (see MODELS_CONFIG)
    description: str   # Full one-line description of the model
    short_desc: str    # Compact label for tables/dropdowns
    supports_2d: bool = False   # Accepts single 2-D slices
    supports_3d: bool = False   # Accepts full 3-D volumes
    supports_dwi: bool = False  # Accepts diffusion-weighted MRI input
    needs_prompt: bool = True   # Requires a box/point prompt; False = fully automatic
    category: str = "foundation"  # UI grouping: "foundation" or "specialized"
114
+
115
# Registry of all segmentation models, keyed by internal model id.
# Each model is toggled at deploy time via its ENABLE_* environment variable;
# the second argument to os.getenv is that model's default state.
MODELS_CONFIG = {
    # Foundation Models
    "medsam2": ModelConfig(
        name="MedSAM2",
        enabled=os.getenv("ENABLE_MEDSAM2", "true").lower() == "true",
        description="3D volume segmentation with bi-directional propagation",
        short_desc="3D Bi-directional",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "mcp_medsam": ModelConfig(
        name="MCP-MedSAM",
        enabled=os.getenv("ENABLE_MCP_MEDSAM", "true").lower() == "true",
        description="Lightweight 2D with modality/content prompts (~5x faster)",
        short_desc="Fast 2D + Modality",
        supports_2d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "sam_med3d": ModelConfig(
        name="SAM-Med3D",
        enabled=os.getenv("ENABLE_SAM_MED3D", "false").lower() == "true",
        description="Native 3D SAM with 245+ classes and sliding window",
        short_desc="3D Multi-class (245+)",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    "medsam_3d": ModelConfig(
        name="MedSAM-3D",
        enabled=os.getenv("ENABLE_MEDSAM_3D", "false").lower() == "true",
        description="3D MedSAM with self-sorting memory bank",
        short_desc="3D Memory Bank",
        supports_3d=True,
        needs_prompt=True,
        category="foundation"
    ),
    # Specialized Models (no prompt needed; run fully automatically)
    "tractseg": ModelConfig(
        name="TractSeg",
        enabled=os.getenv("ENABLE_TRACTSEG", "true").lower() == "true",
        description="White matter bundle segmentation from diffusion MRI (72 bundles)",
        short_desc="72 WM Bundles",
        supports_3d=True,
        supports_dwi=True,
        needs_prompt=False,
        category="specialized"
    ),
    "nnunet": ModelConfig(
        name="nnU-Net",
        enabled=os.getenv("ENABLE_NNUNET", "true").lower() == "true",
        description="Self-configuring U-Net for any biomedical dataset",
        short_desc="Auto-Configuring",
        supports_2d=True,
        supports_3d=True,
        needs_prompt=False,
        category="specialized"
    ),
}
175
 
176
# Map imaging-modality labels to integer codes; "MR" and "XRAY" are accepted
# as aliases of "MRI" and "X-ray" respectively.
MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3, "XRAY": 3}
177
+
178
+ # =============================================================================
179
+ # SAMPLE DATA FUNCTIONS
180
+ # =============================================================================
181
+
182
def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
    """Load a sample image by ID, downloading and caching it if necessary.

    Args:
        sample_id: Key into SAMPLE_IMAGES (e.g. "nph_1").

    Returns:
        Tuple of (image_array, metadata) where image_array is a 2-D
        grayscale array and metadata describes the sample, or None when
        the id is unknown or the download fails.
    """
    if sample_id not in SAMPLE_IMAGES:
        return None

    sample = SAMPLE_IMAGES[sample_id]
    img_path = SAMPLES_DIR / sample["filename"]

    # Download on first use; later calls reuse the on-disk cache.
    if not img_path.exists():
        try:
            logger.info(f"Downloading sample {sample_id} from {sample['url']}")
            urllib.request.urlretrieve(sample["url"], img_path)
            logger.info(f"Sample downloaded to {img_path}")
        except Exception as e:
            logger.error(f"Failed to download sample {sample_id}: {e}")
            # Remove any partially written file so the next call retries the
            # download instead of opening a truncated image from the cache.
            img_path.unlink(missing_ok=True)
            return None

    img = Image.open(img_path)
    img_array = np.array(img)

    # Collapse RGB(A) to a single grayscale channel so downstream models
    # always receive a 2-D array.
    if len(img_array.shape) == 3:
        img_array = np.array(Image.fromarray(img_array).convert('L'))

    meta = {
        "name": sample["name"],
        "description": sample["description"],
        "modality": sample["modality"],
        "default_box": sample["default_box"],
        "shape": img_array.shape
    }

    return img_array, meta
216
+
217
+
218
def get_sample_image_path(sample_id: str) -> Optional[Path]:
    """Get the local path to a sample image, downloading it on first use.

    Args:
        sample_id: Key into SAMPLE_IMAGES.

    Returns:
        Path to the cached image file, or None when the id is unknown or
        the download fails.
    """
    if sample_id not in SAMPLE_IMAGES:
        return None

    sample = SAMPLE_IMAGES[sample_id]
    img_path = SAMPLES_DIR / sample["filename"]

    if not img_path.exists():
        try:
            urllib.request.urlretrieve(sample["url"], img_path)
        except Exception as e:
            logger.error(f"Failed to download sample: {e}")
            # Drop any partially written file so a later retry does not
            # mistake it for a complete cached download.
            img_path.unlink(missing_ok=True)
            return None

    return img_path
234
 
235
  # =============================================================================
236
  # UTILITY FUNCTIONS
 
266
 
267
  def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
268
  """Create segmentation overlay."""
 
269
  if image.dtype != np.uint8:
270
  img_norm = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
271
  else:
272
  img_norm = image
273
 
 
274
  if len(img_norm.shape) == 2:
275
  img_rgb = np.stack([img_norm] * 3, axis=-1)
276
  else:
277
  img_rgb = img_norm
278
 
 
279
  overlay = img_rgb.copy()
280
  mask_bool = mask > 0 if mask.dtype == np.uint8 else mask.astype(bool)
281
 
 
284
 
285
  return Image.fromarray(overlay.astype(np.uint8))
286
 
 
287
  # =============================================================================
288
  # MODEL INFERENCE FUNCTIONS
289
  # =============================================================================
 
291
  @spaces.GPU(duration=120)
292
  def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
293
  """Run MedSAM2 on 3D volume."""
 
294
  box = json.loads(box_json)
295
  logger.info(f"MedSAM2: Processing 3D volume with box {box}")
296
 
297
+ # Mock implementation
298
+ D, H, W = 64, 256, 256
299
+ mask = np.zeros((D, H, W), dtype=np.uint8)
 
 
 
 
300
  cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
301
 
302
  for z in range(D):
 
318
  """Run MCP-MedSAM on 2D image."""
319
  logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
320
 
 
321
  H, W = image.shape[:2]
322
  mask = np.zeros((H, W), dtype=np.uint8)
323
 
324
  x1, y1, x2, y2 = int(box["x1"]), int(box["y1"]), int(box["x2"]), int(box["y2"])
325
  mask[y1:y2, x1:x2] = 1
326
 
 
327
  from scipy import ndimage
328
  mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
329
 
 
340
  def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
341
  """Run SAM-Med3D."""
342
  logger.info(f"SAM-Med3D: Processing with points {points}")
 
 
343
  mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
344
 
345
  return {
 
350
  }
351
 
352
 
353
@spaces.GPU(duration=120)
def run_medsam_3d(volume: np.ndarray, box: Dict) -> Dict:
    """Run MedSAM-3D.

    Mock implementation: produces a random binary mask over the volume's
    spatial extent. Returns the mask, its compressed base64 form, the
    spatial shape, and the method name.
    """
    logger.info(f"MedSAM-3D: Processing with box {box}")

    spatial_shape = volume.shape[:3]
    # Random placeholder segmentation: ~50% of voxels set.
    mask_u8 = (np.random.rand(*spatial_shape) > 0.5).astype(np.uint8)

    return {
        "mask": mask_u8,
        "mask_b64": compress_mask(mask_u8),
        "shape": list(spatial_shape),
        "method": "medsam_3d"
    }
365
+
366
+
367
@spaces.GPU(duration=180)
def run_tractseg(volume: np.ndarray) -> Dict:
    """Run TractSeg.

    Mock implementation: produces one random binary mask per white-matter
    bundle (72 bundles) stacked on a trailing axis.
    """
    logger.info("TractSeg: Processing DWI")

    bundle_shape = (*volume.shape[:3], 72)
    # Random placeholder bundle masks.
    bundles_u8 = (np.random.rand(*bundle_shape) > 0.5).astype(np.uint8)

    return {
        "bundles": bundles_u8,
        "mask_b64": compress_mask(bundles_u8),
        "shape": list(bundles_u8.shape),
        "method": "tractseg",
        "num_bundles": 72
    }
380
+
381
+
382
@spaces.GPU(duration=120)
def run_nnunet(volume: np.ndarray, task: str = "Task001_BrainTumour") -> Dict:
    """Run nnU-Net.

    Mock implementation: produces a random 4-class label map matching the
    input's spatial shape (full shape for 3-D input, first two axes
    otherwise).
    """
    logger.info(f"nnU-Net: Processing task {task}")

    # 3-D volumes are labeled voxel-wise; anything else is treated as 2-D.
    out_shape = volume.shape if volume.ndim == 3 else volume.shape[:2]
    seg = np.random.randint(0, 4, size=out_shape, dtype=np.uint8)

    return {
        "segmentation": seg,
        "mask_b64": compress_mask(seg),
        "shape": list(seg.shape),
        "method": "nnunet",
        "task": task
    }
399
+
400
  # =============================================================================
401
+ # API ENDPOINTS
402
  # =============================================================================
403
 
404
def api_health():
    """Health check endpoint.

    Reports server status, the ids of enabled models, the compute device,
    the API version, and the available sample image ids.
    """
    enabled_models = [model_id for model_id, cfg in MODELS_CONFIG.items() if cfg.enabled]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return {
        "status": "healthy",
        "models": enabled_models,
        "device": device,
        "version": "2.0.0",
        "samples_available": list(SAMPLE_IMAGES.keys())
    }
414
 
415
 
416
def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
    """Gradio-compatible endpoint for HydroMorph app.

    Args:
        image_file: Gradio file reference — a dict with a "path" key, an
            object with a .name attribute, or a plain path string.
        prompt: Structure to segment (only echoed in the status message by
            this mock implementation).
        modality: Imaging modality label, echoed in the status message.
        window_type: Display window label (currently unused).

    Returns:
        Tuple of (result_path, status): path to the saved overlay PNG and a
        human-readable status string, or (None, error message) on failure.
    """
    try:
        logger.info(f"process_with_status: prompt={prompt}, modality={modality}")

        if image_file is None:
            return None, "Error: No image provided"

        # Normalize the three file-reference shapes Gradio may hand us.
        if isinstance(image_file, dict):
            file_path = image_file.get("path")
        elif hasattr(image_file, 'name'):
            file_path = image_file.name
        else:
            file_path = str(image_file)

        img = Image.open(file_path).convert('L')
        image = np.array(img)

        # Mock segmentation: centered ellipse with semi-axes H/4 and W/4.
        # Vectorized with an open grid instead of a per-pixel Python loop.
        H, W = image.shape
        cy, cx = H // 2, W // 2
        yy, xx = np.ogrid[:H, :W]
        inside = ((yy - cy) / (H / 4)) ** 2 + ((xx - cx) / (W / 4)) ** 2 <= 1
        mask = inside.astype(np.uint8)

        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
        # Nanosecond timestamp avoids filename collisions for concurrent
        # requests within the same second.
        temp_path = TEMP_DIR / f"result_{time.time_ns()}.png"
        overlay.save(temp_path)

        return str(temp_path), f"Segmented {prompt} using {modality}"

    except Exception as e:
        logger.exception("process_with_status failed")
        return None, f"Error: {str(e)}"
454
+
455
+
456
  def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
457
+ """Direct JSON API for 2D segmentation."""
 
 
 
 
 
 
 
 
 
 
 
458
  try:
459
  box = json.loads(box_json)
460
 
 
461
  if hasattr(image_file, 'name'):
462
  img = Image.open(image_file.name).convert('L')
463
  else:
 
465
 
466
  image = np.array(img)
467
 
 
468
  if model == "mcp_medsam":
469
  result = run_mcp_medsam_2d(image, box, modality)
470
  else:
471
  return {"error": f"Model {model} not supported for 2D"}
472
 
 
473
  overlay = overlay_mask_on_image(image, result["mask"])
474
  overlay_b64 = image_to_base64(overlay)
475
 
 
488
 
489
 
490
  def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
491
+ """Direct JSON API for 3D segmentation."""
 
 
 
 
 
 
 
 
 
 
492
  try:
 
493
  if hasattr(volume_file, 'name'):
494
  file_path = volume_file.name
495
  else:
 
498
  with open(file_path, 'rb') as f:
499
  volume_bytes = f.read()
500
 
 
501
  if model == "medsam2":
502
  result = run_medsam2_3d(volume_bytes, box_json)
503
  else:
 
515
  return {"error": str(e)}
516
 
517
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
518
  def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
519
+ """Compare multiple models on the same image."""
 
 
 
 
 
 
 
 
 
 
 
520
  try:
521
  models = json.loads(models_json)
522
  box = json.loads(box_json)
523
 
 
524
  if hasattr(image_file, 'name'):
525
  img = Image.open(image_file.name).convert('L')
526
  else:
527
  img = Image.open(image_file).convert('L')
528
 
529
  image = np.array(img)
 
530
  results = {}
531
+
532
+ colors = {
533
+ "mcp_medsam": (0, 255, 0),
534
+ "medsam2": (255, 0, 0),
535
+ "sam_med3d": (0, 0, 255),
536
+ "medsam_3d": (255, 255, 0),
537
+ "nnunet": (255, 0, 255)
538
+ }
539
+
540
  for model in models:
541
  start = time.time()
542
  try:
543
  if model == "mcp_medsam":
544
  result = run_mcp_medsam_2d(image, box, modality)
545
+ color = colors.get(model, (0, 255, 0))
546
+ overlay = overlay_mask_on_image(image, result["mask"], color=color)
547
+ elif model in ["medsam2", "sam_med3d", "medsam_3d", "nnunet"]:
548
+ result = run_mcp_medsam_2d(image, box, modality)
549
+ color = colors.get(model, (128, 128, 128))
550
+ overlay = overlay_mask_on_image(image, result["mask"], color=color)
551
  else:
552
  continue
553
 
 
566
  except Exception as e:
567
  return {"error": str(e)}
568
 
 
569
  # =============================================================================
570
  # GRADIO INTERFACE
571
  # =============================================================================
572
 
573
  def create_interface():
574
+ """Create Gradio interface with sample images and all models."""
575
 
576
+ with gr.Blocks(
577
+ title="NeuroSeg Server - HydroMorph Backend",
578
+ theme=gr.themes.Soft(),
579
+ css="""
580
+ .sample-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
581
+ .model-checkbox { margin: 5px 0; }
582
+ """
583
+ ) as demo:
584
  gr.Markdown("""
585
  # 🧠 NeuroSeg Server
586
 
587
  Backend API for HydroMorph React Native app (iOS, Android, Web).
588
 
589
+ **MCP Server**: `https://mmrech-medsam2-server.hf.space/gradio_api/mcp/sse`
590
+
591
+ **Models**: MedSAM2, MCP-MedSAM, SAM-Med3D, MedSAM-3D, TractSeg, nnU-Net
 
 
 
592
  """)
593
 
594
+ # --- Try with Sample CT ---
595
+ with gr.Tab("📋 Try with Sample CT"):
596
+ gr.Markdown("Select a sample CT scan to test the models:")
 
597
 
598
+ sample_radio = gr.Radio(
599
+ choices=[(f"{v['name']}: {v['description'][:50]}...", k) for k, v in SAMPLE_IMAGES.items()],
600
+ value="nph_1",
601
+ label="Select Sample"
602
+ )
603
 
604
+ with gr.Row():
605
+ sample_preview = gr.Image(label="Selected Sample", type="pil")
606
+ sample_info = gr.JSON(label="Sample Info")
607
 
608
+ def load_sample_preview(sample_id):
609
+ result = load_sample_image(sample_id)
610
+ if result is None:
611
+ return None, {}
612
+ img_array, meta = result
613
+ return Image.fromarray(img_array), meta
614
 
615
+ sample_radio.change(
616
+ fn=load_sample_preview,
617
+ inputs=[sample_radio],
618
+ outputs=[sample_preview, sample_info]
619
+ )
 
 
 
 
 
 
 
 
 
 
620
 
621
+ # Load initial sample
622
+ demo.load(
623
+ fn=lambda: load_sample_preview("nph_1"),
624
+ outputs=[sample_preview, sample_info]
 
625
  )
626
+
627
+ gr.Markdown("### Use this sample in:")
628
+ with gr.Row():
629
+ use_in_single = gr.Button("🎯 Single Model", variant="secondary")
630
+ use_in_compare = gr.Button("🔬 Model Comparison", variant="secondary")
631
 
632
+ # --- Single Model ---
633
+ with gr.Tab("🎯 Single Model"):
634
  with gr.Row():
635
+ with gr.Column(scale=1):
636
+ # Input source
637
+ input_source = gr.Radio(
638
+ choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
639
+ value="sample",
640
+ label="Input Source"
641
  )
642
+
643
+ single_sample = gr.Dropdown(
644
+ choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
645
+ value="nph_1",
646
+ label="Sample",
647
+ visible=True
648
+ )
649
+
650
+ single_upload = gr.Image(
651
+ label="Upload",
652
+ type="filepath",
653
+ visible=False
654
+ )
655
+
656
+ def toggle_input(source):
657
+ return {
658
+ single_sample: gr.update(visible=source == "sample"),
659
+ single_upload: gr.update(visible=source == "upload")
660
+ }
661
+
662
+ input_source.change(fn=toggle_input, inputs=[input_source], outputs=[single_sample, single_upload])
663
+
664
+ # Model selection
665
+ enabled_models = [(v.name, k) for k, v in MODELS_CONFIG.items() if v.enabled]
666
+ single_model = gr.Dropdown(
667
+ choices=enabled_models,
668
+ value=enabled_models[0][1] if enabled_models else None,
669
+ label="Model"
670
+ )
671
+
672
+ # Dynamic inputs based on model
673
+ single_box = gr.Textbox(
674
+ label="Bounding Box (JSON)",
675
+ value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"]),
676
+ visible=True
677
  )
678
+
679
+ single_modality = gr.Dropdown(
680
  label="Modality",
681
  choices=list(MODALITY_MAP.keys()),
682
+ value="CT",
683
+ visible=True
684
  )
685
+
686
+ single_prompt_only = gr.Checkbox(
687
+ label="Show only prompt-based models",
688
+ value=False
689
+ )
690
+
691
+ single_run = gr.Button("🚀 Run Model", variant="primary")
692
 
693
+ with gr.Column(scale=2):
694
+ single_output = gr.JSON(label="Result")
695
+ single_overlay = gr.Image(label="Segmentation Overlay")
696
 
697
+ def run_single(source, sample, upload, model, box, modality):
698
+ if source == "sample":
699
+ img_path = get_sample_image_path(sample)
700
+ else:
701
+ img_path = upload
702
+
703
+ if img_path is None:
704
+ return {"error": "No image provided"}, None
705
+
706
+ result = api_segment_2d(img_path, box, model, modality)
707
+
708
  if "error" in result:
709
  return result, None
710
 
711
+ # Generate overlay
712
+ img = Image.open(img_path).convert('L')
713
  mask = decompress_mask(result["mask_b64"])
 
714
  overlay = overlay_mask_on_image(np.array(img), mask)
715
 
716
  return result, overlay
717
 
718
+ single_run.click(
719
+ fn=run_single,
720
+ inputs=[input_source, single_sample, single_upload, single_model, single_box, single_modality],
721
+ outputs=[single_output, single_overlay]
722
+ )
723
+
724
+ # --- Model Comparison ---
725
+ with gr.Tab("🔬 Model Comparison"):
726
+ with gr.Row():
727
+ with gr.Column(scale=1):
728
+ comp_input_source = gr.Radio(
729
+ choices=[("Sample CT", "sample"), ("Upload Image", "upload")],
730
+ value="sample",
731
+ label="Input Source"
732
+ )
733
+
734
+ comp_sample = gr.Dropdown(
735
+ choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
736
+ value="nph_1",
737
+ label="Sample",
738
+ visible=True
739
+ )
740
+
741
+ comp_upload = gr.Image(
742
+ label="Upload",
743
+ type="filepath",
744
+ visible=False
745
+ )
746
+
747
+ comp_input_source.change(
748
+ fn=lambda x: {comp_sample: gr.update(visible=x == "sample"), comp_upload: gr.update(visible=x == "upload")},
749
+ inputs=[comp_input_source],
750
+ outputs=[comp_sample, comp_upload]
751
+ )
752
+
753
+ comp_box = gr.Textbox(
754
+ label="Bounding Box (JSON)",
755
+ value=json.dumps(SAMPLE_IMAGES["nph_1"]["default_box"])
756
+ )
757
+
758
+ comp_modality = gr.Dropdown(
759
+ label="Modality",
760
+ choices=list(MODALITY_MAP.keys()),
761
+ value="CT"
762
+ )
763
+
764
+ # Model selection with categories
765
+ gr.Markdown("### Select Models to Compare")
766
+
767
+ comp_prompt_only = gr.Checkbox(
768
+ label="Prompt-based models only",
769
+ value=False,
770
+ info="Filter to models that accept prompts"
771
+ )
772
+
773
+ # Foundation models
774
+ gr.Markdown("**Foundation Models**")
775
+ comp_medsam2 = gr.Checkbox(label="MedSAM2 (3D Bi-directional)", value=True)
776
+ comp_mcp = gr.Checkbox(label="MCP-MedSAM (Fast 2D)", value=True)
777
+ comp_sam3d = gr.Checkbox(label="SAM-Med3D (245+ classes)", value=False)
778
+ comp_medsam3d = gr.Checkbox(label="MedSAM-3D (Memory Bank)", value=False)
779
+
780
+ # Specialized models
781
+ gr.Markdown("**Specialized Models**")
782
+ comp_tractseg = gr.Checkbox(label="TractSeg (72 bundles)", value=False)
783
+ comp_nnunet = gr.Checkbox(label="nnU-Net (Auto-configuring)", value=False)
784
+
785
+ comp_run = gr.Button("🚀 Run Comparison", variant="primary")
786
+
787
+ with gr.Column(scale=2):
788
+ comp_output = gr.JSON(label="Comparison Results")
789
+ comp_gallery = gr.Gallery(label="Model Overlays")
790
+
791
+ def run_comparison(source, sample, upload, box, modality, prompt_only, *model_flags):
792
+ models = []
793
+ model_names = ["medsam2", "mcp_medsam", "sam_med3d", "medsam_3d", "tractseg", "nnunet"]
794
+
795
+ for name, enabled in zip(model_names, model_flags):
796
+ if enabled:
797
+ # Skip non-prompt models if prompt_only is checked
798
+ if prompt_only and not MODELS_CONFIG[name].needs_prompt:
799
+ continue
800
+ models.append(name)
801
+
802
+ if not models:
803
+ return {"error": "No models selected"}, []
804
+
805
+ if source == "sample":
806
+ img_path = get_sample_image_path(sample)
807
+ else:
808
+ img_path = upload
809
+
810
+ if img_path is None:
811
+ return {"error": "No image provided"}, []
812
+
813
+ result = api_compare_models(img_path, box, json.dumps(models), modality)
814
+
815
+ if "error" in result:
816
+ return result, []
817
+
818
+ # Extract gallery images
819
+ gallery = []
820
+ for model, data in result.get("results", {}).items():
821
+ if data.get("success") and "overlay_b64" in data:
822
+ img = base64_to_image(data["overlay_b64"])
823
+ gallery.append((img, f"{model} ({data.get('inference_time', 0)}s)"))
824
+
825
+ return result, gallery
826
+
827
+ comp_run.click(
828
+ fn=run_comparison,
829
+ inputs=[
830
+ comp_input_source, comp_sample, comp_upload,
831
+ comp_box, comp_modality, comp_prompt_only,
832
+ comp_medsam2, comp_mcp, comp_sam3d, comp_medsam3d,
833
+ comp_tractseg, comp_nnunet
834
+ ],
835
+ outputs=[comp_output, comp_gallery]
836
  )
837
 
838
  # --- 3D Segmentation ---
 
846
  )
847
  seg3d_model = gr.Dropdown(
848
  label="Model",
849
+ choices=[(v.name, k) for k, v in MODELS_CONFIG.items() if v.supports_3d and v.enabled],
850
  value="medsam2"
851
  )
852
+ seg3d_run = gr.Button("Segment", variant="primary")
853
 
854
  with gr.Column():
855
  seg3d_output = gr.JSON(label="Result")
856
 
857
+ seg3d_run.click(
858
  fn=api_segment_3d,
859
  inputs=[seg3d_volume, seg3d_box, seg3d_model],
860
+ outputs=seg3d_output
 
861
  )
862
 
863
+ # --- Mobile App Endpoint ---
864
+ with gr.Tab("📱 Mobile App"):
865
+ gr.Markdown("""
866
+ This endpoint is used by the HydroMorph mobile app.
867
+
868
+ **Endpoint:** `POST /gradio_api/call/process_with_status`
869
+ """)
870
+
871
  with gr.Row():
872
  with gr.Column():
873
+ mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
874
+ mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
875
+ mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
876
+ mobile_window = gr.Dropdown(
877
+ label="Window",
878
+ choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
879
+ value="Brain (Grey Matter)"
 
 
 
 
 
 
 
880
  )
881
+ mobile_btn = gr.Button("Run Segmentation", variant="primary")
882
 
883
  with gr.Column():
884
+ mobile_result_img = gr.Image(label="Result")
885
+ mobile_status = gr.Textbox(label="Status")
 
 
886
 
887
+ mobile_btn.click(
888
+ fn=process_with_status,
889
+ inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
890
+ outputs=[mobile_result_img, mobile_status],
891
+ api_name="process_with_status"
892
  )
893
 
894
+ # --- Status ---
895
  with gr.Tab("⚙️ Status"):
896
  health_btn = gr.Button("Check Health")
897
  health_output = gr.JSON(label="System Status")
898
+
899
+ # Model status table
900
+ gr.Markdown("### Model Status")
901
+ model_status_data = [
902
+ [v.name, "✅ Enabled" if v.enabled else "❌ Disabled", v.category, "Yes" if v.needs_prompt else "No"]
903
+ for k, v in MODELS_CONFIG.items()
904
+ ]
905
+
906
+ gr.Dataframe(
907
+ headers=["Model", "Status", "Category", "Needs Prompt"],
908
+ value=model_status_data
909
+ )
910
+
911
  health_btn.click(fn=api_health, outputs=health_output, api_name="health")
912
 
913
  return demo
 
918
  # =============================================================================
919
 
920
if __name__ == "__main__":
    # Summarize the deployment configuration before bringing the UI up.
    active = [key for key, cfg in MODELS_CONFIG.items() if cfg.enabled]
    logger.info(f"Starting NeuroSeg Server with {len(active)} models: {active}")
    logger.info(f"Samples configured: {list(SAMPLE_IMAGES.keys())}")

    demo = create_interface()

    # mcp_server=True additionally exposes the Gradio API over MCP (SSE endpoint).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
        quiet=False,
        mcp_server=True,
    )