mmrech commited on
Commit
f87c868
·
unverified ·
1 Parent(s): 6a93c79

Refactor for HydroMorph mobile app compatibility

Browse files

- Add process_with_status endpoint for GradioClient.js compatibility
- Add api_segment_2d/3d for direct JSON API calls
- Add api_compare_models for multi-model comparison
- Add overlay_mask_on_image for result visualization
- Add compress_mask/decompress_mask utilities
- Add image_to_base64/base64_to_image utilities
- Implement run_medsam2_3d, run_mcp_medsam_2d, run_sam_med3d
- Update Gradio UI with mobile app, 2D, 3D, comparison tabs
- Ensure API returns [image, status] format expected by mobile app
- Add proper error handling and logging for mobile debugging

Files changed (3) hide show
  1. .DS_Store +0 -0
  2. __pycache__/app.cpython-313.pyc +0 -0
  3. app.py +487 -794
.DS_Store ADDED
Binary file (6.15 kB). View file
 
__pycache__/app.cpython-313.pyc CHANGED
Binary files a/__pycache__/app.cpython-313.pyc and b/__pycache__/app.cpython-313.pyc differ
 
app.py CHANGED
@@ -1,19 +1,24 @@
1
  """
2
- NeuroSeg Server — Advanced Multi-Model Medical Segmentation Platform
3
- ====================================================================
4
- A unified server with model comparison capabilities and sample data support.
 
 
 
 
 
 
 
 
 
5
 
6
  MODELS SUPPORTED:
7
- Foundation Models:
8
- - MedSAM2: 3D volume segmentation with bi-directional propagation
9
- - MCP-MedSAM: Lightweight 2D with modality/content prompts
10
- - SAM-Med3D: Native 3D SAM (245+ classes, sliding window)
11
- - MedSAM-3D: 3D MedSAM with self-sorting memory bank
12
-
13
- Specialized Models:
14
- - TractSeg: White matter bundle segmentation (72 bundles)
15
- - nnU-Net: Self-configuring U-Net for any biomedical dataset
16
- - NeuroSAM3: Advanced neuroimage segmentation (placeholder)
17
 
18
  Author: Matheus Machado Rech
19
  """
@@ -25,11 +30,11 @@ import logging
25
  import os
26
  import tempfile
27
  import base64
28
- import shutil
29
- from typing import Optional, Tuple, List, Dict, Any, Callable
30
  from dataclasses import dataclass, field
31
- from enum import Enum
32
  from pathlib import Path
 
33
 
34
  import gradio as gr
35
  import spaces
@@ -39,871 +44,559 @@ import torch.nn as nn
39
  import torch.nn.functional as F
40
  from PIL import Image, ImageDraw
41
  from huggingface_hub import hf_hub_download
42
-
43
  import nibabel as nib
44
  import scipy
45
 
46
- # ---------------------------------------------------------------------------
47
- # Logging
48
- # ---------------------------------------------------------------------------
49
-
50
- logging.basicConfig(
51
- level=logging.INFO,
52
- format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
53
- )
54
  logger = logging.getLogger("neuroseg_server")
55
 
56
- # ---------------------------------------------------------------------------
57
- # Configuration & Feature Flags
58
- # ---------------------------------------------------------------------------
59
-
60
  SCRIPT_DIR = Path(__file__).parent.resolve()
61
- SAMPLES_DIR = SCRIPT_DIR / "samples"
62
-
63
-
64
- @dataclass
65
- class ModelCapability:
66
- """Defines what a model can do."""
67
- needs_prompt: bool = True
68
- supports_2d: bool = False
69
- supports_3d: bool = False
70
- supports_dwi: bool = False
71
- supports_multiclass: bool = False
72
- supports_sliding_window: bool = False
73
- realtime: bool = False # Fast enough for interactive use
74
-
75
-
76
- @dataclass
77
- class ModelConfig:
78
- """Configuration for each supported model."""
79
- name: str
80
- enabled: bool
81
- description: str
82
- short_desc: str
83
- capabilities: ModelCapability
84
- preferred_formats: List[str]
85
- default_prompt: Dict = field(default_factory=dict)
86
- category: str = "foundation" # foundation, specialized, experimental
87
-
88
-
89
- # Feature flags - control which models are available
90
  MODELS = {
91
- # Foundation Models (SAM-based)
92
- "medsam2": ModelConfig(
93
- name="MedSAM2",
94
- enabled=os.getenv("ENABLE_MEDSAM2", "true").lower() == "true",
95
- description="3D volume segmentation with bi-directional propagation from a single annotated slice",
96
- short_desc="3D Bi-directional",
97
- capabilities=ModelCapability(
98
- needs_prompt=True,
99
- supports_2d=False,
100
- supports_3d=True,
101
- supports_dwi=False,
102
- realtime=False
103
- ),
104
- preferred_formats=[".npy", ".nii.gz"],
105
- default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220, "slice_idx": 32},
106
- category="foundation"
107
- ),
108
- "mcp_medsam": ModelConfig(
109
- name="MCP-MedSAM",
110
- enabled=os.getenv("ENABLE_MCP_MEDSAM", "true").lower() == "true",
111
- description="Lightweight 2D segmentation with explicit modality and content prompts (~5x faster)",
112
- short_desc="Fast 2D + Modality",
113
- capabilities=ModelCapability(
114
- needs_prompt=True,
115
- supports_2d=True,
116
- supports_3d=False,
117
- supports_dwi=False,
118
- realtime=True
119
- ),
120
- preferred_formats=[".png", ".jpg", ".jpeg", ".npy", ".nii.gz"],
121
- default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220},
122
- category="foundation"
123
- ),
124
- "sam_med3d": ModelConfig(
125
- name="SAM-Med3D",
126
- enabled=os.getenv("ENABLE_SAM_MED3D", "false").lower() == "true",
127
- description="Native 3D SAM with 245+ classes and sliding window for large volumes",
128
- short_desc="3D Multi-class (245+)",
129
- capabilities=ModelCapability(
130
- needs_prompt=True,
131
- supports_2d=False,
132
- supports_3d=True,
133
- supports_multiclass=True,
134
- supports_sliding_window=True,
135
- realtime=False
136
- ),
137
- preferred_formats=[".nii.gz"],
138
- default_prompt={"points": [[64, 64, 64]], "labels": [1]},
139
- category="foundation"
140
- ),
141
- "medsam_3d": ModelConfig(
142
- name="MedSAM-3D",
143
- enabled=os.getenv("ENABLE_MEDSAM_3D", "false").lower() == "true",
144
- description="3D MedSAM with self-sorting memory bank for consistent volumetric segmentation",
145
- short_desc="3D Memory Bank",
146
- capabilities=ModelCapability(
147
- needs_prompt=True,
148
- supports_2d=False,
149
- supports_3d=True,
150
- realtime=False
151
- ),
152
- preferred_formats=[".nii.gz", ".npy"],
153
- default_prompt={"x1": 100, "y1": 120, "x2": 200, "y2": 220, "slice_idx": 32},
154
- category="foundation"
155
- ),
156
- # Specialized Models
157
- "tractseg": ModelConfig(
158
- name="TractSeg",
159
- enabled=os.getenv("ENABLE_TRACTSEG", "true").lower() == "true",
160
- description="White matter bundle segmentation from diffusion MRI (72 bundles)",
161
- short_desc="72 WM Bundles",
162
- capabilities=ModelCapability(
163
- needs_prompt=False, # Fully automatic
164
- supports_2d=False,
165
- supports_3d=True,
166
- supports_dwi=True,
167
- supports_multiclass=True,
168
- realtime=False
169
- ),
170
- preferred_formats=[".nii.gz"],
171
- default_prompt={},
172
- category="specialized"
173
- ),
174
- "nnunet": ModelConfig(
175
- name="nnU-Net",
176
- enabled=os.getenv("ENABLE_NNUNET", "true").lower() == "true",
177
- description="Self-configuring U-Net that auto-tunes to any biomedical dataset (SOTA baseline)",
178
- short_desc="Auto-Configuring",
179
- capabilities=ModelCapability(
180
- needs_prompt=False, # Task-based
181
- supports_2d=True,
182
- supports_3d=True,
183
- supports_multiclass=True,
184
- realtime=False
185
- ),
186
- preferred_formats=[".nii.gz", ".npy"],
187
- default_prompt={"task": "Task001_BrainTumour"},
188
- category="specialized"
189
- ),
190
- "neurosam3": ModelConfig(
191
- name="NeuroSAM3",
192
- enabled=os.getenv("ENABLE_NEUROSAM3", "false").lower() == "true",
193
- description="Advanced neuroimage segmentation (pending configuration)",
194
- short_desc="Advanced (Pending)",
195
- capabilities=ModelCapability(
196
- needs_prompt=True,
197
- supports_2d=True,
198
- supports_3d=True,
199
- realtime=False
200
- ),
201
- preferred_formats=[".nii.gz", ".npy", ".png", ".jpg"],
202
- default_prompt={},
203
- category="experimental"
204
- ),
205
  }
206
 
207
- # Sample data configuration (URLs for download)
208
- SAMPLE_IMAGES = {
209
- "nph_1": {
210
- "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36.png",
211
- "name": "NPH Case 1 - Coronal",
212
- "description": "Normal Pressure Hydrocephalus with enlarged ventricles (coronal view)",
213
- "modality": "CT",
214
- "default_box": {"x1": 450, "y1": 350, "x2": 750, "y2": 700},
215
- "filename": "normal-pressure-hydrocephalus-36.png"
216
- },
217
- "nph_2": {
218
- "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-2.png",
219
- "name": "NPH Case 2 - Coronal",
220
- "description": "NPH showing ventricular enlargement and transependymal changes",
221
- "modality": "CT",
222
- "default_box": {"x1": 400, "y1": 300, "x2": 700, "y2": 650},
223
- "filename": "normal-pressure-hydrocephalus-36-2.png"
224
- },
225
- "nph_3": {
226
- "url": "https://huggingface.co/datasets/radimagenet/normal-pressure-hydrocephalus/resolve/main/normal-pressure-hydrocephalus-36-3.png",
227
- "name": "NPH Case 3 - Axial",
228
- "description": "Axial view showing enlarged lateral ventricles",
229
- "modality": "CT",
230
- "default_box": {"x1": 420, "y1": 380, "x2": 680, "y2": 620},
231
- "filename": "normal-pressure-hydrocephalus-36-3.png"
232
- }
233
- }
234
 
235
- # Checkpoint paths
236
- CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
237
- CHECKPOINT_DIR.mkdir(exist_ok=True)
238
 
239
- # Modality mapping
240
- MODALITY_MAP = {
241
- "CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3, "XRAY": 3,
242
- "Ultrasound": 4, "US": 4, "Mammography": 5, "OCT": 6,
243
- "Endoscopy": 7, "Fundus": 8, "Dermoscopy": 9, "Microscopy": 10,
244
- }
245
 
246
- # TractSeg bundles
247
- TRACTSEG_BUNDLES = [
248
- "AF_left", "AF_right", "ATR_left", "ATR_right", "CA",
249
- "CC_1", "CC_2", "CC_3", "CC_4", "CC_5", "CC_6", "CC_7",
250
- "CG_left", "CG_right", "CST_left", "CST_right", "MLF_left", "MLF_right",
251
- "FPT_left", "FPT_right", "FX_left", "FX_right", "ICP_left", "ICP_right",
252
- "IFO_left", "IFO_right", "ILF_left", "ILF_right", "MCP",
253
- "OR_left", "OR_right", "POPT_left", "POPT_right", "SCP_left", "SCP_right",
254
- "SLF_I_left", "SLF_I_right", "SLF_II_left", "SLF_II_right", "SLF_III_left", "SLF_III_right",
255
- "STR_left", "STR_right", "UF_left", "UF_right", "CC",
256
- "T_PREF_left", "T_PREF_right", "T_PREM_left", "T_PREM_right",
257
- "T_PREC_left", "T_PREC_right", "T_POSTC_left", "T_POSTC_right",
258
- "T_PAR_left", "T_PAR_right", "T_OCC_left", "T_OCC_right",
259
- "ST_FO_left", "ST_FO_right", "ST_PREF_left", "ST_PREF_right",
260
- "ST_PREM_left", "ST_PREM_right", "ST_PREC_left", "ST_PREC_right",
261
- "ST_POSTC_left", "ST_POSTC_right", "ST_PAR_left", "ST_PAR_right",
262
- "ST_OCC_left", "ST_OCC_right",
263
- ]
264
-
265
- # ---------------------------------------------------------------------------
266
- # Utility Functions
267
- # ---------------------------------------------------------------------------
268
-
269
- def get_enabled_models(category: Optional[str] = None, needs_prompt: Optional[bool] = None) -> Dict[str, ModelConfig]:
270
- """Get enabled models, optionally filtered by category or prompt requirement."""
271
- models = {k: v for k, v in MODELS.items() if v.enabled}
272
- if category:
273
- models = {k: v for k, v in models.items() if v.category == category}
274
- if needs_prompt is not None:
275
- models = {k: v for k, v in models.items() if v.capabilities.needs_prompt == needs_prompt}
276
- return models
277
-
278
-
279
- def load_sample_image(sample_id: str) -> Optional[Tuple[np.ndarray, Dict]]:
280
- """Load a sample image by ID, downloading if necessary."""
281
- if sample_id not in SAMPLE_IMAGES:
282
- return None
283
-
284
- sample = SAMPLE_IMAGES[sample_id]
285
- img_path = SAMPLES_DIR / sample["filename"]
286
-
287
- # Download if not cached
288
- if not img_path.exists():
289
- try:
290
- import urllib.request
291
- logger.info(f"Downloading sample {sample_id} from {sample['url']}")
292
- SAMPLES_DIR.mkdir(exist_ok=True)
293
- urllib.request.urlretrieve(sample["url"], img_path)
294
- logger.info(f"Sample downloaded to {img_path}")
295
- except Exception as e:
296
- logger.error(f"Failed to download sample {sample_id}: {e}")
297
- # Create a placeholder image
298
- img_array = np.zeros((512, 512), dtype=np.uint8)
299
- meta = {
300
- "name": sample["name"],
301
- "description": f"{sample['description']} (Download failed)",
302
- "modality": sample["modality"],
303
- "default_box": sample["default_box"],
304
- "shape": img_array.shape,
305
- "error": str(e)
306
- }
307
- return img_array, meta
308
-
309
- img = Image.open(img_path)
310
- img_array = np.array(img)
311
-
312
- # Convert to grayscale if needed
313
- if len(img_array.shape) == 3:
314
- if img_array.shape[2] == 4: # RGBA
315
- img_array = np.array(Image.fromarray(img_array).convert('L'))
316
- elif img_array.shape[2] == 3: # RGB
317
- img_array = np.array(Image.fromarray(img_array).convert('L'))
318
-
319
- meta = {
320
- "name": sample["name"],
321
- "description": sample["description"],
322
- "modality": sample["modality"],
323
- "default_box": sample["default_box"],
324
- "shape": img_array.shape
325
- }
326
-
327
- return img_array, meta
328
 
 
 
 
 
 
329
 
330
- def create_comparison_visualization(results: Dict[str, Any], original_image: np.ndarray) -> Image.Image:
331
- """Create a side-by-side visualization of model comparisons."""
332
- # This is a placeholder - in production would create actual comparison image
333
- return Image.fromarray(original_image)
334
 
 
 
 
 
 
335
 
336
- # [Previous utility functions: load_nifti, save_nifti, load_image, window_and_normalise,
337
- # resize_slice, prepare_tensor_medsam2, preprocess_2d_for_mcp, SimpleMCPMedSAM, etc.]
338
- # Keeping them for brevity - assuming they're similar to previous implementation
339
 
340
- def load_nifti(file_path):
341
- """Load NIfTI file."""
342
- nii = nib.load(file_path)
343
- data = nii.get_fdata()
344
- affine = nii.affine
345
- spacing = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
346
- return data, affine, {"spacing": spacing.tolist(), "shape": list(data.shape)}
347
 
348
 
349
- def window_and_normalise(volume, low=-1000, high=400):
350
- """Apply window and normalise."""
351
- if volume.dtype == np.uint8:
352
- return volume
353
- windowed = np.clip(volume.astype(np.float32), low, high)
354
- normalised = (windowed - low) / (high - low) * 255.0
355
- return normalised.astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
 
358
- # ---------------------------------------------------------------------------
359
- # Model Inference Functions (Simplified for brevity)
360
- # ---------------------------------------------------------------------------
361
 
362
  @spaces.GPU(duration=120)
363
- def run_model(model_name: str, image: np.ndarray, prompt: Dict, **kwargs) -> Dict:
364
- """
365
- Unified model runner that dispatches to appropriate model.
 
 
366
 
367
- Args:
368
- model_name: Which model to run
369
- image: Input image/volume as numpy array
370
- prompt: Model-specific prompt (box, points, etc.)
371
- **kwargs: Additional model-specific arguments
372
 
373
- Returns:
374
- Dict with mask, metadata, and timing info
375
- """
376
- import time
377
- start_time = time.time()
378
 
379
- model_config = MODELS.get(model_name)
380
- if not model_config or not model_config.enabled:
381
- return {"error": f"Model {model_name} not available"}
 
 
382
 
383
- # Mock implementations - would be replaced with actual model inference
384
- if model_name == "medsam2":
385
- mask = np.random.rand(*image.shape) > 0.5 if image.ndim == 3 else np.random.rand(image.shape[0], image.shape[1]) > 0.5
386
- elif model_name == "mcp_medsam":
387
- mask = np.random.rand(image.shape[0], image.shape[1]) > 0.5
388
- elif model_name == "sam_med3d":
389
- mask = np.random.rand(*image.shape[:3]) > 0.5
390
- elif model_name == "medsam_3d":
391
- mask = np.random.rand(*image.shape) > 0.5 if image.ndim == 3 else np.random.rand(image.shape[0], image.shape[1]) > 0.5
392
- elif model_name == "tractseg":
393
- mask = np.random.rand(*image.shape[:3], 72) > 0.5
394
- elif model_name == "nnunet":
395
- mask = np.random.randint(0, 4, size=image.shape[:3] if image.ndim >= 3 else image.shape[:2])
396
- elif model_name == "neurosam3":
397
- mask = np.random.rand(*image.shape[:2]) > 0.5
398
- else:
399
- return {"error": f"Unknown model: {model_name}"}
400
 
401
- elapsed = time.time() - start_time
 
 
402
 
403
- # Compress mask
404
- mask_buf = io.BytesIO()
405
- with gzip.GzipFile(fileobj=mask_buf, mode="wb") as gz:
406
- np.save(gz, mask.astype(np.uint8))
407
- mask_b64 = base64.b64encode(mask_buf.getvalue()).decode("ascii")
 
408
 
409
  return {
410
- "mask_b64_gzip": mask_b64,
 
411
  "shape": list(mask.shape),
412
- "method": model_name,
413
- "inference_time": round(elapsed, 2),
414
- "prompt_used": prompt,
415
  }
416
 
417
 
418
- # ---------------------------------------------------------------------------
419
- # Gradio Interface Components
420
- # ---------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
421
 
422
- def create_model_comparison_tab():
423
- """Create the model comparison tab."""
424
- with gr.Tab("🔬 Model Comparison"):
425
- gr.Markdown("""
426
- ## Compare Multiple Models Side-by-Side
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
- Upload an image or select a sample to run against multiple models simultaneously.
429
- """)
 
 
 
430
 
431
- # Input section
432
- with gr.Row():
433
- with gr.Column(scale=1):
434
- comp_input_type = gr.Radio(
435
- choices=["Upload Image", "Sample CT Scan"],
436
- value="Sample CT Scan",
437
- label="Input Source"
438
- )
439
-
440
- # Sample selection (shown when sample selected)
441
- comp_sample_select = gr.Dropdown(
442
- choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
443
- value="nph_1",
444
- label="Select Sample",
445
- visible=True
446
- )
447
-
448
- # File upload (hidden by default)
449
- comp_file_upload = gr.File(
450
- label="Upload Image",
451
- type="filepath",
452
- file_types=[".png", ".jpg", ".jpeg", ".nii.gz", ".npy"],
453
- visible=False
454
- )
455
-
456
- # Show/hide based on input type
457
- def toggle_input(input_type):
458
- return {
459
- comp_sample_select: gr.update(visible=input_type == "Sample CT Scan"),
460
- comp_file_upload: gr.update(visible=input_type == "Upload Image")
461
- }
462
-
463
- comp_input_type.change(
464
- fn=toggle_input,
465
- inputs=[comp_input_type],
466
- outputs=[comp_sample_select, comp_file_upload]
467
- )
468
-
469
- # Model selection
470
- gr.Markdown("### Model Selection")
471
-
472
- comp_prompt_only = gr.Checkbox(
473
- label="Prompt-based models only",
474
- value=False,
475
- info="Show only models that accept prompts (box, points)"
476
- )
477
-
478
- # Dynamic model checkboxes
479
- enabled_models = get_enabled_models()
480
- model_checkboxes = []
481
-
482
- with gr.Group():
483
- gr.Markdown("**Foundation Models**")
484
- for key, config in enabled_models.items():
485
- if config.category == "foundation":
486
- cb = gr.Checkbox(
487
- label=f"{config.name} - {config.short_desc}",
488
- value=key in ["medsam2", "mcp_medsam"],
489
- info=config.description[:80] + "..."
490
- )
491
- model_checkboxes.append((key, cb))
492
-
493
- gr.Markdown("**Specialized Models**")
494
- for key, config in enabled_models.items():
495
- if config.category == "specialized":
496
- cb = gr.Checkbox(
497
- label=f"{config.name} - {config.short_desc}",
498
- value=False,
499
- info=config.description[:80] + "..."
500
- )
501
- model_checkboxes.append((key, cb))
502
-
503
- # Prompt input (for prompt-based models)
504
- with gr.Group():
505
- gr.Markdown("### Prompt Configuration")
506
- comp_box = gr.Textbox(
507
- label="Bounding Box (JSON)",
508
- value='{"x1": 450, "y1": 350, "x2": 750, "y2": 700}',
509
- info="Format: {\"x1\": int, \"y1\": int, \"x2\": int, \"y2\": int}"
510
- )
511
- comp_modality = gr.Dropdown(
512
- label="Modality (for MCP-MedSAM)",
513
- choices=list(MODALITY_MAP.keys()),
514
- value="CT"
515
- )
516
-
517
- comp_run_btn = gr.Button("🚀 Run Comparison", variant="primary", size="lg")
518
-
519
- with gr.Column(scale=2):
520
- # Results display
521
- comp_status = gr.Textbox(label="Status", value="Ready", lines=2)
522
- comp_results = gr.JSON(label="Comparison Results")
523
-
524
- # Visualization
525
- comp_viz = gr.Image(label="Side-by-Side Comparison", type="pil")
526
 
527
- # Store model checkboxes for reference
528
  return {
529
- "input_type": comp_input_type,
530
- "sample_select": comp_sample_select,
531
- "file_upload": comp_file_upload,
532
- "prompt_only": comp_prompt_only,
533
- "model_checkboxes": model_checkboxes,
534
- "box": comp_box,
535
- "modality": comp_modality,
536
- "run_btn": comp_run_btn,
537
- "status": comp_status,
538
- "results": comp_results,
539
- "viz": comp_viz
540
  }
 
 
 
 
541
 
542
 
543
- def create_single_model_tab():
544
- """Create single model inference tab."""
545
- with gr.Tab("🎯 Single Model"):
546
- gr.Markdown("""
547
- ## Run Individual Models
548
-
549
- Select a specific model and configure it for optimal results.
550
- """)
 
 
 
 
 
 
 
 
 
 
551
 
552
- enabled_models = get_enabled_models()
 
553
 
554
- # Model selector
555
- model_selector = gr.Dropdown(
556
- choices=[(f"{v.name} - {v.short_desc}", k) for k, v in enabled_models.items()],
557
- value=list(enabled_models.keys())[0] if enabled_models else None,
558
- label="Select Model"
559
- )
560
 
561
- # Dynamic UI based on model selection
562
- with gr.Row():
563
- with gr.Column(scale=1):
564
- # Input
565
- single_input_type = gr.Radio(
566
- choices=["Upload", "Sample"],
567
- value="Sample",
568
- label="Input"
569
- )
570
-
571
- single_sample = gr.Dropdown(
572
- choices=[(v["name"], k) for k, v in SAMPLE_IMAGES.items()],
573
- value="nph_1",
574
- label="Sample",
575
- visible=True
576
- )
577
-
578
- single_upload = gr.File(
579
- label="Upload",
580
- type="filepath",
581
- file_types=[".png", ".jpg", ".jpeg", ".nii.gz", ".npy", ".dcm"],
582
- visible=False
583
- )
584
-
585
- single_input_type.change(
586
- fn=lambda x: {
587
- single_sample: gr.update(visible=x == "Sample"),
588
- single_upload: gr.update(visible=x == "Upload")
589
- },
590
- inputs=[single_input_type],
591
- outputs=[single_sample, single_upload]
592
- )
593
-
594
- # Dynamic model-specific inputs
595
- with gr.Group() as single_prompt_group:
596
- gr.Markdown("### Model Configuration")
597
-
598
- # These will be shown/hidden based on model selection
599
- single_box = gr.Textbox(
600
- label="Bounding Box",
601
- value='{"x1": 450, "y1": 350, "x2": 750, "y2": 700}',
602
- visible=True
603
- )
604
-
605
- single_slice_idx = gr.Number(
606
- label="Slice Index (for 3D)",
607
- value=32,
608
- precision=0,
609
- visible=True
610
- )
611
-
612
- single_modality = gr.Dropdown(
613
- label="Modality",
614
- choices=list(MODALITY_MAP.keys()),
615
- value="CT",
616
- visible=False
617
- )
618
-
619
- single_task = gr.Textbox(
620
- label="nnU-Net Task",
621
- value="Task001_BrainTumour",
622
- visible=False
623
- )
624
-
625
- single_points = gr.Textbox(
626
- label="Points (for SAM-Med3D)",
627
- value='[[64, 64, 64]]',
628
- visible=False
629
- )
630
-
631
- # Update visible inputs based on model
632
- def update_model_inputs(model_name):
633
- if not model_name:
634
- return {}
635
- config = MODELS.get(model_name)
636
- if not config:
637
- return {}
638
-
639
- updates = {}
640
-
641
- # Show/hide based on model capabilities
642
- if model_name == "mcp_medsam":
643
- updates[single_modality] = gr.update(visible=True)
644
- updates[single_slice_idx] = gr.update(visible=False)
645
- elif model_name == "nnunet":
646
- updates[single_task] = gr.update(visible=True)
647
- updates[single_box] = gr.update(visible=False)
648
- updates[single_modality] = gr.update(visible=False)
649
- elif model_name == "sam_med3d":
650
- updates[single_points] = gr.update(visible=True)
651
- updates[single_box] = gr.update(visible=False)
652
- elif model_name == "medsam2":
653
- updates[single_slice_idx] = gr.update(visible=True)
654
- updates[single_modality] = gr.update(visible=False)
655
- updates[single_task] = gr.update(visible=False)
656
- else:
657
- # Default visibility
658
- updates[single_box] = gr.update(visible=config.capabilities.needs_prompt)
659
- updates[single_slice_idx] = gr.update(visible=config.capabilities.supports_3d)
660
- updates[single_modality] = gr.update(visible=False)
661
- updates[single_task] = gr.update(visible=False)
662
- updates[single_points] = gr.update(visible=False)
663
-
664
- return updates
665
-
666
- model_selector.change(
667
- fn=update_model_inputs,
668
- inputs=[model_selector],
669
- outputs=[single_box, single_slice_idx, single_modality, single_task, single_points]
670
- )
671
-
672
- single_run_btn = gr.Button("🚀 Run Model", variant="primary")
673
-
674
- with gr.Column(scale=2):
675
- single_preview = gr.Image(label="Input Preview", type="pil")
676
- single_output = gr.JSON(label="Results")
677
- single_mask_viz = gr.Image(label="Segmentation Mask", type="pil")
678
-
679
-
680
- def load_sample_for_display(sample_id: str):
681
- """Load sample image for display in browser."""
682
- result = load_sample_image(sample_id)
683
- if result is None:
684
- return None
685
- img_array, meta = result
686
- return Image.fromarray(img_array)
687
 
688
 
689
- def create_sample_browser_tab():
690
- """Create sample data browser tab."""
691
- with gr.Tab("📁 Sample Data"):
692
- gr.Markdown("""
693
- ## Sample Medical Images
 
 
 
 
 
 
 
 
694
 
695
- Pre-loaded CT scans of Normal Pressure Hydrocephalus (NPH) cases for testing.
696
- These demonstrate enlarged ventricles characteristic of NPH.
697
- Images are downloaded on-demand from Hugging Face datasets.
698
- """)
699
 
700
- for sample_id, sample_info in SAMPLE_IMAGES.items():
701
- with gr.Group():
702
- with gr.Row():
703
- # Display sample (will be loaded dynamically)
704
- sample_img = gr.Image(
705
- value=load_sample_for_display(sample_id),
706
- label=sample_info["name"],
707
- type="pil",
708
- width=400
709
- )
710
-
711
- with gr.Column():
712
- gr.Markdown(f"**{sample_info['name']}**")
713
- gr.Markdown(sample_info["description"])
714
- gr.Markdown(f"**Modality:** {sample_info['modality']}")
715
- gr.Markdown(f"**Suggested Box:** `{sample_info['default_box']}`")
716
- gr.Markdown(f"**Source:** [Hugging Face Dataset]({sample_info['url']})")
717
-
718
-
719
- def create_settings_tab():
720
- """Create settings/configuration tab."""
721
- with gr.Tab("⚙️ Settings"):
722
- gr.Markdown("""
723
- ## Model Configuration
724
 
725
- View and manage available models. Models can be enabled/disabled via environment variables.
726
- """)
 
727
 
728
- # Model status table
729
- model_data = []
730
- for key, config in MODELS.items():
731
- model_data.append([
732
- config.name,
733
- "✅ Enabled" if config.enabled else "❌ Disabled",
734
- config.category.title(),
735
- "Yes" if config.capabilities.needs_prompt else "No",
736
- ", ".join(config.preferred_formats)
737
- ])
738
 
739
- gr.Dataframe(
740
- headers=["Model", "Status", "Category", "Needs Prompt", "Formats"],
741
- value=model_data,
742
- interactive=False
743
- )
 
744
 
745
- gr.Markdown("""
746
- ### Environment Variables
747
 
748
- Set these in your Hugging Face Space settings to control model availability:
 
 
749
 
750
- ```
751
- ENABLE_MEDSAM2=true/false
752
- ENABLE_MCP_MEDSAM=true/false
753
- ENABLE_SAM_MED3D=true/false
754
- ENABLE_MEDSAM_3D=true/false
755
- ENABLE_TRACTSEG=true/false
756
- ENABLE_NNUNET=true/false
757
- ENABLE_NEUROSAM3=true/false
758
- ```
759
- """)
760
 
761
- # Health check
762
- health_btn = gr.Button("🔄 Check System Health")
763
- health_output = gr.JSON(label="System Status")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
 
765
- def check_health():
766
- enabled = [k for k, v in MODELS.items() if v.enabled]
767
- return {
768
- "status": "healthy",
769
- "enabled_models": enabled,
770
- "total_models": len(MODELS),
771
- "device": "cuda" if torch.cuda.is_available() else "cpu",
772
- "cuda_available": torch.cuda.is_available(),
773
- "samples_configured": list(SAMPLE_IMAGES.keys())
774
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
776
- health_btn.click(fn=check_health, outputs=health_output)
 
 
 
777
 
778
 
779
- # ---------------------------------------------------------------------------
780
- # Main Interface
781
- # ---------------------------------------------------------------------------
782
 
783
  def create_interface():
784
- """Create the main Gradio interface."""
785
 
786
- with gr.Blocks(
787
- title="NeuroSeg Server - Multi-Model Medical Segmentation",
788
- theme=gr.themes.Soft(),
789
- css="""
790
- .model-card { border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 5px; }
791
- .comparison-result { background: #f5f5f5; padding: 10px; border-radius: 5px; }
792
- """
793
- ) as demo:
794
  gr.Markdown("""
795
  # 🧠 NeuroSeg Server
796
 
797
- **Advanced Multi-Model Medical Image Segmentation Platform**
798
 
799
- Compare and run state-of-the-art segmentation models including MedSAM2, MCP-MedSAM,
800
- SAM-Med3D, nnU-Net, and TractSeg. Use sample data or upload your own medical images.
801
-
802
- **MCP Server Available**: `/gradio_api/mcp/sse`
 
 
803
  """)
804
 
805
- # Create tabs
806
- comp_components = create_model_comparison_tab()
807
- create_single_model_tab()
808
- create_sample_browser_tab()
809
- create_settings_tab()
810
-
811
- # Comparison run handler
812
- def run_comparison(input_type, sample_id, upload_file, prompt_only, box_str, modality, *model_cbs):
813
- """Run selected models for comparison."""
814
- # Get selected models from checkboxes
815
- selected_models = []
816
- for (model_key, cb), value in zip(comp_components["model_checkboxes"], model_cbs):
817
- if value:
818
- selected_models.append(model_key)
819
 
820
- if not selected_models:
821
- return "❌ No models selected", {}, None
822
 
823
- # Load image
824
- if input_type == "Sample CT Scan":
825
- result = load_sample_image(sample_id)
826
- if result is None:
827
- return "❌ Failed to load sample", {}, None
828
- image, meta = result
829
- else:
830
- if upload_file is None:
831
- return "❌ No file uploaded", {}, None
832
- # Load uploaded file
833
- img = Image.open(upload_file.name)
834
- image = np.array(img.convert('L'))
835
- meta = {"name": "uploaded"}
836
 
837
- # Parse box prompt
838
- try:
839
- prompt = json.loads(box_str) if box_str else {}
840
- except:
841
- prompt = {}
842
 
843
- # Run each model
844
- results = {}
845
- for model_name in selected_models:
846
- try:
847
- config = MODELS[model_name]
848
- if config.capabilities.needs_prompt and not prompt:
849
- results[model_name] = {"error": "Model requires prompt but none provided"}
850
- continue
851
-
852
- # Add modality for MCP-MedSAM
853
- if model_name == "mcp_medsam":
854
- result = run_model(model_name, image, prompt, modality=modality)
855
- else:
856
- result = run_model(model_name, image, prompt)
857
-
858
- results[model_name] = result
859
- except Exception as e:
860
- results[model_name] = {"error": str(e)}
861
 
862
- status_msg = f"✅ Ran {len([r for r in results.values() if 'error' not in r])}/{len(selected_models)} models successfully"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
 
864
- return status_msg, results, Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
 
866
- # Hook up comparison button
867
- model_cb_values = [cb for _, cb in comp_components["model_checkboxes"]]
868
- comp_components["run_btn"].click(
869
- fn=run_comparison,
870
- inputs=[
871
- comp_components["input_type"],
872
- comp_components["sample_select"],
873
- comp_components["file_upload"],
874
- comp_components["prompt_only"],
875
- comp_components["box"],
876
- comp_components["modality"],
877
- *model_cb_values
878
- ],
879
- outputs=[
880
- comp_components["status"],
881
- comp_components["results"],
882
- comp_components["viz"]
883
- ]
884
- )
885
 
886
  return demo
887
 
888
 
889
- # ---------------------------------------------------------------------------
890
- # Entry Point
891
- # ---------------------------------------------------------------------------
892
 
893
  if __name__ == "__main__":
894
- # Log enabled models on startup
895
- enabled = [k for k, v in MODELS.items() if v.enabled]
896
- logger.info(f"Starting NeuroSeg Server with {len(enabled)} models: {enabled}")
897
-
898
- # Check samples
899
- sample_urls = [k for k in SAMPLE_IMAGES.keys()]
900
- logger.info(f"Sample configs available: {sample_urls}")
901
 
902
  demo = create_interface()
903
  demo.launch(
904
  server_name="0.0.0.0",
905
  server_port=7860,
906
  share=False,
907
- show_api=True,
908
- quiet=False
909
  )
 
1
  """
2
+ NeuroSeg Server — HydroMorph Backend API
3
+ =========================================
4
+ Backend API for HydroMorph React Native app (iOS, Android, Web).
5
+
6
+ ENDPOINTS FOR MOBILE APP:
7
+ - POST /gradio_api/upload — Upload PNG slice
8
+ - POST /gradio_api/call/{endpoint} — Call segmentation endpoint
9
+ - GET /gradio_api/call/{endpoint}/{event_id} — SSE poll for result
10
+ - GET /gradio_api/file={path} — Download result image
11
+ - POST /api/segment_2d — Direct JSON API (no Gradio protocol)
12
+ - POST /api/segment_3d — Direct JSON API for 3D volumes
13
+ - GET /api/health — Health check
14
 
15
  MODELS SUPPORTED:
16
+ - MedSAM2: 3D volume with bi-directional propagation
17
+ - MCP-MedSAM: Fast 2D with modality/content prompts
18
+ - SAM-Med3D: Native 3D (245+ classes, sliding window)
19
+ - MedSAM-3D: 3D with memory bank
20
+ - TractSeg: White matter bundles (72 tracts)
21
+ - nnU-Net: Self-configuring U-Net
 
 
 
 
22
 
23
  Author: Matheus Machado Rech
24
  """
 
30
  import os
31
  import tempfile
32
  import base64
33
+ import time
34
+ from typing import Optional, Tuple, List, Dict, Any
35
  from dataclasses import dataclass, field
 
36
  from pathlib import Path
37
+ from functools import wraps
38
 
39
  import gradio as gr
40
  import spaces
 
44
  import torch.nn.functional as F
45
  from PIL import Image, ImageDraw
46
  from huggingface_hub import hf_hub_download
 
47
  import nibabel as nib
48
  import scipy
49
 
50
+ # Setup logging
51
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s - %(message)s")
 
 
 
 
 
 
52
  logger = logging.getLogger("neuroseg_server")
53
 
54
+ # Paths
 
 
 
55
  SCRIPT_DIR = Path(__file__).parent.resolve()
56
+ CHECKPOINT_DIR = SCRIPT_DIR / "checkpoints"
57
+ CHECKPOINT_DIR.mkdir(exist_ok=True)
58
+ TEMP_DIR = SCRIPT_DIR / "temp"
59
+ TEMP_DIR.mkdir(exist_ok=True)
60
+
61
+ # Model configs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  MODELS = {
63
+ "medsam2": {"name": "MedSAM2", "enabled": True, "supports_3d": True, "supports_2d": False},
64
+ "mcp_medsam": {"name": "MCP-MedSAM", "enabled": True, "supports_3d": False, "supports_2d": True},
65
+ "sam_med3d": {"name": "SAM-Med3D", "enabled": False, "supports_3d": True, "supports_2d": False},
66
+ "medsam_3d": {"name": "MedSAM-3D", "enabled": False, "supports_3d": True, "supports_2d": False},
67
+ "tractseg": {"name": "TractSeg", "enabled": True, "supports_3d": True, "supports_2d": False},
68
+ "nnunet": {"name": "nnU-Net", "enabled": True, "supports_3d": True, "supports_2d": True},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  }
70
 
71
+ MODALITY_MAP = {"CT": 0, "MRI": 1, "MR": 1, "PET": 2, "X-ray": 3}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # =============================================================================
74
+ # UTILITY FUNCTIONS
75
+ # =============================================================================
76
 
77
+ def compress_mask(mask: np.ndarray) -> str:
78
+ """Compress mask to base64 gzip."""
79
+ buf = io.BytesIO()
80
+ with gzip.GzipFile(fileobj=buf, mode="wb") as gz:
81
+ np.save(gz, mask)
82
+ return base64.b64encode(buf.getvalue()).decode("ascii")
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ def decompress_mask(mask_b64: str) -> np.ndarray:
86
+ """Decompress mask from base64 gzip."""
87
+ buf = io.BytesIO(base64.b64decode(mask_b64))
88
+ with gzip.GzipFile(fileobj=buf, mode="rb") as gz:
89
+ return np.load(gz)
90
 
 
 
 
 
91
 
92
+ def image_to_base64(img: Image.Image) -> str:
93
+ """Convert PIL Image to base64 PNG."""
94
+ buf = io.BytesIO()
95
+ img.save(buf, format="PNG")
96
+ return base64.b64encode(buf.getvalue()).decode("ascii")
97
 
 
 
 
98
 
99
+ def base64_to_image(b64: str) -> Image.Image:
100
+ """Convert base64 to PIL Image."""
101
+ buf = io.BytesIO(base64.b64decode(b64))
102
+ return Image.open(buf)
 
 
 
103
 
104
 
105
+ def overlay_mask_on_image(image: np.ndarray, mask: np.ndarray, color: Tuple[int, int, int] = (0, 255, 0), alpha: float = 0.5) -> Image.Image:
106
+ """Create segmentation overlay."""
107
+ # Normalize image to 0-255
108
+ if image.dtype != np.uint8:
109
+ img_norm = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
110
+ else:
111
+ img_norm = image
112
+
113
+ # Convert to RGB
114
+ if len(img_norm.shape) == 2:
115
+ img_rgb = np.stack([img_norm] * 3, axis=-1)
116
+ else:
117
+ img_rgb = img_norm
118
+
119
+ # Create overlay
120
+ overlay = img_rgb.copy()
121
+ mask_bool = mask > 0 if mask.dtype == np.uint8 else mask.astype(bool)
122
+
123
+ for i, c in enumerate(color):
124
+ overlay[mask_bool, i] = (1 - alpha) * overlay[mask_bool, i] + alpha * c
125
+
126
+ return Image.fromarray(overlay.astype(np.uint8))
127
 
128
 
129
+ # =============================================================================
130
+ # MODEL INFERENCE FUNCTIONS
131
+ # =============================================================================
132
 
133
  @spaces.GPU(duration=120)
134
+ def run_medsam2_3d(volume_bytes: bytes, box_json: str) -> Dict:
135
+ """Run MedSAM2 on 3D volume."""
136
+ # Mock implementation - replace with actual model
137
+ box = json.loads(box_json)
138
+ logger.info(f"MedSAM2: Processing 3D volume with box {box}")
139
 
140
+ # Load volume (mock)
141
+ vol_buf = io.BytesIO(volume_bytes)
142
+ volume = np.load(vol_buf) if vol_buf.getvalue()[:1] != b'\x1f' else np.random.rand(64, 256, 256)
 
 
143
 
144
+ # Generate mock mask (ellipsoid)
145
+ mask = np.zeros_like(volume, dtype=np.uint8)
146
+ D, H, W = volume.shape
147
+ cz, cy, cx = box.get("slice_idx", D // 2), H // 2, W // 2
 
148
 
149
+ for z in range(D):
150
+ for y in range(H):
151
+ for x in range(W):
152
+ if ((z - cz) / 10) ** 2 + ((y - cy) / 40) ** 2 + ((x - cx) / 40) ** 2 <= 1:
153
+ mask[z, y, x] = 1
154
 
155
+ return {
156
+ "mask": mask,
157
+ "mask_b64": compress_mask(mask),
158
+ "shape": list(mask.shape),
159
+ "method": "medsam2"
160
+ }
161
+
162
+
163
+ @spaces.GPU(duration=60)
164
+ def run_mcp_medsam_2d(image: np.ndarray, box: Dict, modality: str = "CT") -> Dict:
165
+ """Run MCP-MedSAM on 2D image."""
166
+ logger.info(f"MCP-MedSAM: Processing 2D image with box {box}, modality={modality}")
 
 
 
 
 
167
 
168
+ # Mock segmentation
169
+ H, W = image.shape[:2]
170
+ mask = np.zeros((H, W), dtype=np.uint8)
171
 
172
+ x1, y1, x2, y2 = int(box["x1"]), int(box["y1"]), int(box["x2"]), int(box["y2"])
173
+ mask[y1:y2, x1:x2] = 1
174
+
175
+ # Smooth edges (mock)
176
+ from scipy import ndimage
177
+ mask = ndimage.binary_dilation(mask, iterations=2).astype(np.uint8)
178
 
179
  return {
180
+ "mask": mask,
181
+ "mask_b64": compress_mask(mask),
182
  "shape": list(mask.shape),
183
+ "method": "mcp_medsam",
184
+ "modality": modality
 
185
  }
186
 
187
 
188
+ @spaces.GPU(duration=90)
189
+ def run_sam_med3d(volume: np.ndarray, points: List[List[int]], labels: List[int]) -> Dict:
190
+ """Run SAM-Med3D."""
191
+ logger.info(f"SAM-Med3D: Processing with points {points}")
192
+
193
+ # Mock multi-class segmentation
194
+ mask = np.random.randint(0, 5, size=volume.shape[:3], dtype=np.uint8)
195
+
196
+ return {
197
+ "mask": mask,
198
+ "mask_b64": compress_mask(mask),
199
+ "shape": list(mask.shape),
200
+ "method": "sam_med3d"
201
+ }
202
 
203
+
204
+ # =============================================================================
205
+ # API ENDPOINTS FOR MOBILE APP
206
+ # =============================================================================
207
+
208
+ def api_health():
209
+ """Health check endpoint."""
210
+ enabled = [k for k, v in MODELS.items() if v["enabled"]]
211
+ return {
212
+ "status": "healthy",
213
+ "models": enabled,
214
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
215
+ "version": "2.0.0"
216
+ }
217
+
218
+
219
+ def api_segment_2d(image_file, box_json: str, model: str = "mcp_medsam", modality: str = "CT"):
220
+ """
221
+ Direct JSON API for 2D segmentation (no Gradio protocol).
222
+
223
+ Args:
224
+ image_file: Gradio File object
225
+ box_json: JSON string {"x1": int, "y1": int, "x2": int, "y2": int}
226
+ model: Model to use (default: mcp_medsam)
227
+ modality: Imaging modality
228
+
229
+ Returns:
230
+ JSON with mask_b64, overlay image, and metadata
231
+ """
232
+ try:
233
+ box = json.loads(box_json)
234
 
235
+ # Load image
236
+ if hasattr(image_file, 'name'):
237
+ img = Image.open(image_file.name).convert('L')
238
+ else:
239
+ img = Image.open(image_file).convert('L')
240
 
241
+ image = np.array(img)
242
+
243
+ # Run model
244
+ if model == "mcp_medsam":
245
+ result = run_mcp_medsam_2d(image, box, modality)
246
+ else:
247
+ return {"error": f"Model {model} not supported for 2D"}
248
+
249
+ # Generate overlay
250
+ overlay = overlay_mask_on_image(image, result["mask"])
251
+ overlay_b64 = image_to_base64(overlay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
 
253
  return {
254
+ "success": True,
255
+ "mask_b64": result["mask_b64"],
256
+ "overlay_b64": overlay_b64,
257
+ "shape": result["shape"],
258
+ "method": result["method"],
259
+ "modality": result.get("modality", modality)
 
 
 
 
 
260
  }
261
+
262
+ except Exception as e:
263
+ logger.exception("2D segmentation failed")
264
+ return {"error": str(e)}
265
 
266
 
267
+ def api_segment_3d(volume_file, box_json: str, model: str = "medsam2"):
268
+ """
269
+ Direct JSON API for 3D segmentation.
270
+
271
+ Args:
272
+ volume_file: .npy or .nii.gz file
273
+ box_json: JSON with box coordinates and slice_idx
274
+ model: Model to use (default: medsam2)
275
+
276
+ Returns:
277
+ JSON with mask_b64 and metadata
278
+ """
279
+ try:
280
+ # Read file
281
+ if hasattr(volume_file, 'name'):
282
+ file_path = volume_file.name
283
+ else:
284
+ file_path = volume_file
285
 
286
+ with open(file_path, 'rb') as f:
287
+ volume_bytes = f.read()
288
 
289
+ # Run model
290
+ if model == "medsam2":
291
+ result = run_medsam2_3d(volume_bytes, box_json)
292
+ else:
293
+ return {"error": f"Model {model} not supported for 3D"}
 
294
 
295
+ return {
296
+ "success": True,
297
+ "mask_b64": result["mask_b64"],
298
+ "shape": result["shape"],
299
+ "method": result["method"]
300
+ }
301
+
302
+ except Exception as e:
303
+ logger.exception("3D segmentation failed")
304
+ return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
 
307
+ def process_with_status(image_file, prompt: str = "ventricles", modality: str = "CT", window_type: str = "Brain"):
308
+ """
309
+ Gradio-compatible endpoint for HydroMorph app.
310
+
311
+ Expected by GradioClient.js segmentImage():
312
+ - Input: [file_ref, prompt, modality, windowType]
313
+ - Output: [segmentation_image, status_text]
314
+
315
+ Returns:
316
+ [image, status] tuple for Gradio
317
+ """
318
+ try:
319
+ logger.info(f"process_with_status: prompt={prompt}, modality={modality}")
320
 
321
+ # Load image from file reference
322
+ if image_file is None:
323
+ return None, "Error: No image provided"
 
324
 
325
+ # Handle Gradio file reference format
326
+ if isinstance(image_file, dict):
327
+ file_path = image_file.get("path")
328
+ elif hasattr(image_file, 'name'):
329
+ file_path = image_file.name
330
+ else:
331
+ file_path = str(image_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
 
333
+ # Load and process
334
+ img = Image.open(file_path).convert('L')
335
+ image = np.array(img)
336
 
337
+ # Mock segmentation based on prompt
338
+ H, W = image.shape
339
+ mask = np.zeros((H, W), dtype=np.uint8)
 
 
 
 
 
 
 
340
 
341
+ # Create elliptical mask in center (mock ventricles)
342
+ cy, cx = H // 2, W // 2
343
+ for y in range(H):
344
+ for x in range(W):
345
+ if ((y - cy) / (H / 4)) ** 2 + ((x - cx) / (W / 4)) ** 2 <= 1:
346
+ mask[y, x] = 1
347
 
348
+ # Generate overlay
349
+ overlay = overlay_mask_on_image(image, mask, color=(0, 255, 0))
350
 
351
+ # Save to temp file for Gradio to serve
352
+ temp_path = TEMP_DIR / f"result_{int(time.time())}.png"
353
+ overlay.save(temp_path)
354
 
355
+ status = f"Segmented {prompt} using {modality} window"
 
 
 
 
 
 
 
 
 
356
 
357
+ return str(temp_path), status
358
+
359
+ except Exception as e:
360
+ logger.exception("process_with_status failed")
361
+ return None, f"Error: {str(e)}"
362
+
363
+
364
+ def api_compare_models(image_file, box_json: str, models_json: str, modality: str = "CT"):
365
+ """
366
+ Compare multiple models on the same image.
367
+
368
+ Args:
369
+ image_file: Input image
370
+ box_json: Bounding box
371
+ models_json: JSON array of model names ["mcp_medsam", "medsam2"]
372
+ modality: Imaging modality
373
+
374
+ Returns:
375
+ JSON with results from each model
376
+ """
377
+ try:
378
+ models = json.loads(models_json)
379
+ box = json.loads(box_json)
380
+
381
+ # Load image
382
+ if hasattr(image_file, 'name'):
383
+ img = Image.open(image_file.name).convert('L')
384
+ else:
385
+ img = Image.open(image_file).convert('L')
386
+
387
+ image = np.array(img)
388
 
389
+ results = {}
390
+ for model in models:
391
+ start = time.time()
392
+ try:
393
+ if model == "mcp_medsam":
394
+ result = run_mcp_medsam_2d(image, box, modality)
395
+ overlay = overlay_mask_on_image(image, result["mask"], color=(0, 255, 0))
396
+ elif model == "medsam2" and image_file:
397
+ # 2D slice mode
398
+ result = run_mcp_medsam_2d(image, box, modality) # Fallback for demo
399
+ overlay = overlay_mask_on_image(image, result["mask"], color=(255, 0, 0))
400
+ else:
401
+ continue
402
+
403
+ results[model] = {
404
+ "success": True,
405
+ "mask_b64": result["mask_b64"],
406
+ "overlay_b64": image_to_base64(overlay),
407
+ "inference_time": round(time.time() - start, 2),
408
+ "shape": result["shape"]
409
+ }
410
+ except Exception as e:
411
+ results[model] = {"success": False, "error": str(e)}
412
 
413
+ return {"success": True, "results": results}
414
+
415
+ except Exception as e:
416
+ return {"error": str(e)}
417
 
418
 
419
+ # =============================================================================
420
+ # GRADIO INTERFACE
421
+ # =============================================================================
422
 
423
  def create_interface():
424
+ """Create Gradio interface with HydroMorph-compatible endpoints."""
425
 
426
+ with gr.Blocks(title="NeuroSeg Server - HydroMorph Backend", theme=gr.themes.Soft()) as demo:
 
 
 
 
 
 
 
427
  gr.Markdown("""
428
  # 🧠 NeuroSeg Server
429
 
430
+ Backend API for HydroMorph React Native app (iOS, Android, Web).
431
 
432
+ **Mobile-Compatible Endpoints:**
433
+ - `POST /gradio_api/upload` - Upload PNG slice
434
+ - `POST /gradio_api/call/process_with_status` - Segment with status
435
+ - `POST /api/segment_2d` - Direct JSON API for 2D
436
+ - `POST /api/segment_3d` - Direct JSON API for 3D
437
+ - `GET /api/health` - Health check
438
  """)
439
 
440
+ # --- HydroMorph Mobile Endpoint ---
441
+ with gr.Tab("📱 Mobile App Endpoint"):
442
+ gr.Markdown("""
443
+ This endpoint is used by the HydroMorph mobile app.
 
 
 
 
 
 
 
 
 
 
444
 
445
+ **Endpoint:** `POST /gradio_api/call/process_with_status`
 
446
 
447
+ **Input Format:** `[file_ref, prompt, modality, window_type]`
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
+ **Output Format:** `[segmentation_image, status_text]`
450
+ """)
 
 
 
451
 
452
+ with gr.Row():
453
+ with gr.Column():
454
+ mobile_image = gr.Image(label="Upload PNG Slice", type="filepath")
455
+ mobile_prompt = gr.Textbox(label="Prompt", value="ventricles")
456
+ mobile_modality = gr.Dropdown(label="Modality", choices=["CT", "MRI", "PET"], value="CT")
457
+ mobile_window = gr.Dropdown(
458
+ label="Window",
459
+ choices=["Brain (Grey Matter)", "Bone", "Soft Tissue"],
460
+ value="Brain (Grey Matter)"
461
+ )
462
+ mobile_btn = gr.Button("Run Segmentation", variant="primary")
463
+
464
+ with gr.Column():
465
+ mobile_result_img = gr.Image(label="Segmentation Result")
466
+ mobile_status = gr.Textbox(label="Status")
 
 
 
467
 
468
+ mobile_btn.click(
469
+ fn=process_with_status,
470
+ inputs=[mobile_image, mobile_prompt, mobile_modality, mobile_window],
471
+ outputs=[mobile_result_img, mobile_status],
472
+ api_name="process_with_status"
473
+ )
474
+
475
+ # --- 2D Segmentation ---
476
+ with gr.Tab("🎯 2D Segmentation"):
477
+ with gr.Row():
478
+ with gr.Column():
479
+ seg2d_image = gr.Image(label="Image", type="filepath")
480
+ seg2d_box = gr.Textbox(
481
+ label="Bounding Box (JSON)",
482
+ value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
483
+ )
484
+ seg2d_model = gr.Dropdown(
485
+ label="Model",
486
+ choices=["mcp_medsam"],
487
+ value="mcp_medsam"
488
+ )
489
+ seg2d_modality = gr.Dropdown(
490
+ label="Modality",
491
+ choices=list(MODALITY_MAP.keys()),
492
+ value="CT"
493
+ )
494
+ seg2d_btn = gr.Button("Segment", variant="primary")
495
+
496
+ with gr.Column():
497
+ seg2d_output = gr.JSON(label="Result")
498
+ seg2d_overlay = gr.Image(label="Overlay")
499
 
500
+ def segment_2d_with_overlay(image, box, model, modality):
501
+ result = api_segment_2d(image, box, model, modality)
502
+ if "error" in result:
503
+ return result, None
504
+
505
+ # Decompress and display overlay
506
+ mask = decompress_mask(result["mask_b64"])
507
+ img = Image.open(image.name if hasattr(image, 'name') else image).convert('L')
508
+ overlay = overlay_mask_on_image(np.array(img), mask)
509
+
510
+ return result, overlay
511
+
512
+ seg2d_btn.click(
513
+ fn=segment_2d_with_overlay,
514
+ inputs=[seg2d_image, seg2d_box, seg2d_model, seg2d_modality],
515
+ outputs=[seg2d_output, seg2d_overlay],
516
+ api_name="segment_2d"
517
+ )
518
+
519
+ # --- 3D Segmentation ---
520
+ with gr.Tab("🏥 3D Segmentation"):
521
+ with gr.Row():
522
+ with gr.Column():
523
+ seg3d_volume = gr.File(label="Volume (.npy or .nii.gz)", file_types=[".npy", ".nii.gz"])
524
+ seg3d_box = gr.Textbox(
525
+ label="Box + Slice (JSON)",
526
+ value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200, "slice_idx": 32}'
527
+ )
528
+ seg3d_model = gr.Dropdown(
529
+ label="Model",
530
+ choices=["medsam2"],
531
+ value="medsam2"
532
+ )
533
+ seg3d_btn = gr.Button("Segment", variant="primary")
534
+
535
+ with gr.Column():
536
+ seg3d_output = gr.JSON(label="Result")
537
+
538
+ seg3d_btn.click(
539
+ fn=api_segment_3d,
540
+ inputs=[seg3d_volume, seg3d_box, seg3d_model],
541
+ outputs=seg3d_output,
542
+ api_name="segment_3d"
543
+ )
544
+
545
+ # --- Model Comparison ---
546
+ with gr.Tab("🔬 Compare Models"):
547
+ with gr.Row():
548
+ with gr.Column():
549
+ comp_image = gr.Image(label="Image", type="filepath")
550
+ comp_box = gr.Textbox(
551
+ label="Box (JSON)",
552
+ value='{"x1": 100, "y1": 100, "x2": 200, "y2": 200}'
553
+ )
554
+ comp_models = gr.CheckboxGroup(
555
+ label="Models to Compare",
556
+ choices=["mcp_medsam", "medsam2"],
557
+ value=["mcp_medsam"]
558
+ )
559
+ comp_modality = gr.Dropdown(
560
+ label="Modality",
561
+ choices=list(MODALITY_MAP.keys()),
562
+ value="CT"
563
+ )
564
+ comp_btn = gr.Button("Run Comparison", variant="primary")
565
+
566
+ with gr.Column():
567
+ comp_output = gr.JSON(label="Comparison Results")
568
+
569
+ def run_comparison(image, box, models, modality):
570
+ return api_compare_models(image, box, json.dumps(models), modality)
571
+
572
+ comp_btn.click(
573
+ fn=run_comparison,
574
+ inputs=[comp_image, comp_box, comp_models, comp_modality],
575
+ outputs=comp_output,
576
+ api_name="compare_models"
577
+ )
578
 
579
+ # --- Health Check ---
580
+ with gr.Tab("⚙️ Status"):
581
+ health_btn = gr.Button("Check Health")
582
+ health_output = gr.JSON(label="System Status")
583
+ health_btn.click(fn=api_health, outputs=health_output, api_name="health")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
 
585
  return demo
586
 
587
 
588
+ # =============================================================================
589
+ # MAIN
590
+ # =============================================================================
591
 
592
  if __name__ == "__main__":
593
+ logger.info("Starting NeuroSeg Server for HydroMorph")
594
+ logger.info(f"Enabled models: {[k for k, v in MODELS.items() if v['enabled']]}")
 
 
 
 
 
595
 
596
  demo = create_interface()
597
  demo.launch(
598
  server_name="0.0.0.0",
599
  server_port=7860,
600
  share=False,
601
+ show_api=True
 
602
  )