maregu2023 commited on
Commit
f05ab78
·
1 Parent(s): 458f770

Add Phase 11: FastAPI backend for slicer-grade frontend support

Browse files
Files changed (2) hide show
  1. seg_app/backend/__init__.py +10 -0
  2. seg_app/backend/api.py +586 -0
seg_app/backend/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Backend API module for seg_app.
3
+
4
+ Provides a FastAPI-based REST API for slicer-grade frontend support.
5
+ Wraps the existing orchestrator without modification.
6
+ """
7
+
8
+ from seg_app.backend.api import create_api_app
9
+
10
+ __all__ = ["create_api_app"]
seg_app/backend/api.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI backend for medical image segmentation.
3
+
4
+ This module provides a REST API that wraps the existing orchestrator
5
+ to support slicer-grade frontends (VTK.js, MONAI-style viewers).
6
+
7
+ Design principles:
8
+ - No modification to existing orchestrator or model code
9
+ - Importable without side effects (no model loading on import)
10
+ - Simple in-memory state for single-user research use
11
+ - Meaningful HTTP errors without exposing stack traces
12
+
13
+ Usage:
14
+ from seg_app.backend.api import create_api_app
15
+ app = create_api_app()
16
+ # Run with: uvicorn seg_app.backend.api:app --reload
17
+ """
18
+
19
+ import io
20
+ import logging
21
+ import tempfile
22
+ import uuid
23
+ from dataclasses import dataclass, field
24
+ from enum import Enum
25
+ from pathlib import Path
26
+ from typing import Any, Dict, List, Optional, Tuple
27
+
28
+ import numpy as np
29
+ from fastapi import FastAPI, File, HTTPException, UploadFile
30
+ from fastapi.middleware.cors import CORSMiddleware
31
+ from fastapi.responses import Response
32
+ from pydantic import BaseModel, Field
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # =============================================================================
38
+ # Pydantic Models for API Request/Response
39
+ # =============================================================================
40
+
41
+ class PromptTypeEnum(str, Enum):
42
+ """Prompt types matching orchestrator.PromptType."""
43
+ point_positive = "point_positive"
44
+ point_negative = "point_negative"
45
+ bounding_box = "bounding_box"
46
+
47
+
48
+ class PromptModel(BaseModel):
49
+ """A single prompt for interactive segmentation."""
50
+ prompt_type: PromptTypeEnum
51
+ coordinates: List[int] = Field(
52
+ ...,
53
+ description="For points: [d, h, w]; for boxes: [d1, h1, w1, d2, h2, w2]"
54
+ )
55
+
56
+
57
+ class PromptsModel(BaseModel):
58
+ """Collection of prompts for segmentation refinement."""
59
+ items: List[PromptModel] = Field(default_factory=list)
60
+
61
+
62
+ class VolumeUploadResponse(BaseModel):
63
+ """Response after successful volume upload."""
64
+ volume_id: str
65
+ shape: Tuple[int, int, int]
66
+ spacing: Tuple[float, float, float]
67
+ message: str = "Volume uploaded successfully"
68
+
69
+
70
+ class SegmentationRequest(BaseModel):
71
+ """Request to run segmentation on an uploaded volume."""
72
+ volume_id: str
73
+ model_id: str = Field(
74
+ default="medical-sam-3d",
75
+ description="Model ID: 'medical-sam-3d' or 'unet3d-brain-tumor'"
76
+ )
77
+ prompts: Optional[PromptsModel] = None
78
+
79
+
80
+ class RefineRequest(BaseModel):
81
+ """Request to refine an existing segmentation with prompts."""
82
+ volume_id: str
83
+ prompts: PromptsModel
84
+
85
+
86
+ class SegmentationResponse(BaseModel):
87
+ """Response after segmentation completes."""
88
+ volume_id: str
89
+ model_id: str
90
+ task_name: str
91
+ mask_shape: Tuple[int, int, int]
92
+ status: str = "completed"
93
+ metadata: Dict[str, Any] = Field(default_factory=dict)
94
+
95
+
96
+ class ErrorResponse(BaseModel):
97
+ """Standard error response."""
98
+ detail: str
99
+ error_code: str
100
+
101
+
102
+ class AvailableModelsResponse(BaseModel):
103
+ """List of available segmentation models."""
104
+ models: List[Dict[str, str]]
105
+
106
+
107
+ # =============================================================================
108
+ # In-Memory State Storage
109
+ # =============================================================================
110
+
111
+ @dataclass
112
+ class VolumeState:
113
+ """State for a single uploaded volume."""
114
+ volume_id: str
115
+ array: np.ndarray
116
+ spacing: Tuple[float, float, float]
117
+ affine: Optional[np.ndarray] = None
118
+ mask: Optional[np.ndarray] = None
119
+ last_model_id: Optional[str] = None
120
+ metadata: Dict[str, Any] = field(default_factory=dict)
121
+
122
+
123
+ class StateManager:
124
+ """Simple in-memory state manager for volumes and masks.
125
+
126
+ Single-user assumption: no concurrency handling.
127
+ """
128
+
129
+ def __init__(self):
130
+ self._volumes: Dict[str, VolumeState] = {}
131
+
132
+ def create_volume(
133
+ self,
134
+ array: np.ndarray,
135
+ spacing: Tuple[float, float, float],
136
+ affine: Optional[np.ndarray] = None,
137
+ ) -> str:
138
+ """Store a new volume and return its ID."""
139
+ volume_id = str(uuid.uuid4())[:8] # Short ID for convenience
140
+ self._volumes[volume_id] = VolumeState(
141
+ volume_id=volume_id,
142
+ array=array,
143
+ spacing=spacing,
144
+ affine=affine,
145
+ )
146
+ logger.info(f"Created volume {volume_id} with shape {array.shape}")
147
+ return volume_id
148
+
149
+ def get_volume(self, volume_id: str) -> Optional[VolumeState]:
150
+ """Retrieve a volume by ID."""
151
+ return self._volumes.get(volume_id)
152
+
153
+ def update_mask(
154
+ self,
155
+ volume_id: str,
156
+ mask: np.ndarray,
157
+ model_id: str,
158
+ ) -> None:
159
+ """Update the segmentation mask for a volume."""
160
+ if volume_id not in self._volumes:
161
+ raise KeyError(f"Volume not found: {volume_id}")
162
+ self._volumes[volume_id].mask = mask
163
+ self._volumes[volume_id].last_model_id = model_id
164
+ logger.info(f"Updated mask for volume {volume_id} using model {model_id}")
165
+
166
+ def delete_volume(self, volume_id: str) -> bool:
167
+ """Delete a volume and its associated mask."""
168
+ if volume_id in self._volumes:
169
+ del self._volumes[volume_id]
170
+ logger.info(f"Deleted volume {volume_id}")
171
+ return True
172
+ return False
173
+
174
+ def list_volumes(self) -> List[str]:
175
+ """List all stored volume IDs."""
176
+ return list(self._volumes.keys())
177
+
178
+
179
+ # Global state manager instance
180
+ _state_manager = StateManager()
181
+
182
+
183
+ # =============================================================================
184
+ # Helper Functions
185
+ # =============================================================================
186
+
187
+ def _convert_prompts_to_orchestrator(prompts_model: Optional[PromptsModel]):
188
+ """Convert Pydantic PromptsModel to orchestrator Prompts object."""
189
+ if prompts_model is None or len(prompts_model.items) == 0:
190
+ return None
191
+
192
+ # Import here to avoid circular imports and side effects
193
+ from seg_app.inference.orchestrator import Prompts, PromptType, Prompt
194
+
195
+ prompts = Prompts()
196
+ for p in prompts_model.items:
197
+ if p.prompt_type == PromptTypeEnum.point_positive:
198
+ prompt_type = PromptType.POINT_POSITIVE
199
+ elif p.prompt_type == PromptTypeEnum.point_negative:
200
+ prompt_type = PromptType.POINT_NEGATIVE
201
+ else:
202
+ prompt_type = PromptType.BOUNDING_BOX
203
+
204
+ prompts.items.append(Prompt(prompt_type, tuple(p.coordinates)))
205
+
206
+ return prompts
207
+
208
+
209
+ def _get_available_model_ids() -> List[str]:
210
+ """Get list of valid model IDs."""
211
+ # Import here to avoid side effects
212
+ from seg_app.inference.orchestrator import get_available_models
213
+ return [m["id"] for m in get_available_models()]
214
+
215
+
216
+ # =============================================================================
217
+ # FastAPI Application Factory
218
+ # =============================================================================
219
+
220
+ def create_api_app() -> FastAPI:
221
+ """Create and configure the FastAPI application.
222
+
223
+ Returns:
224
+ Configured FastAPI app instance
225
+ """
226
+ app = FastAPI(
227
+ title="Brain Lesion Segmentation API",
228
+ description=(
229
+ "REST API for 3D medical image segmentation. "
230
+ "Supports volume upload, segmentation, and interactive refinement."
231
+ ),
232
+ version="1.0.0",
233
+ docs_url="/docs",
234
+ redoc_url="/redoc",
235
+ )
236
+
237
+ # Configure CORS for local frontend development
238
+ app.add_middleware(
239
+ CORSMiddleware,
240
+ allow_origins=[
241
+ "http://localhost:3000", # React dev server
242
+ "http://localhost:5173", # Vite dev server
243
+ "http://localhost:8080", # Generic dev server
244
+ "http://127.0.0.1:3000",
245
+ "http://127.0.0.1:5173",
246
+ "http://127.0.0.1:8080",
247
+ ],
248
+ allow_credentials=True,
249
+ allow_methods=["*"],
250
+ allow_headers=["*"],
251
+ )
252
+
253
+ # -------------------------------------------------------------------------
254
+ # Health Check
255
+ # -------------------------------------------------------------------------
256
+
257
+ @app.get("/health")
258
+ async def health_check():
259
+ """Health check endpoint."""
260
+ return {"status": "healthy"}
261
+
262
+ # -------------------------------------------------------------------------
263
+ # Model Information
264
+ # -------------------------------------------------------------------------
265
+
266
+ @app.get("/models", response_model=AvailableModelsResponse)
267
+ async def list_models():
268
+ """Get list of available segmentation models."""
269
+ from seg_app.inference.orchestrator import get_available_models
270
+ return AvailableModelsResponse(models=get_available_models())
271
+
272
+ # -------------------------------------------------------------------------
273
+ # Volume Upload
274
+ # -------------------------------------------------------------------------
275
+
276
+ @app.post("/volume/upload", response_model=VolumeUploadResponse)
277
+ async def upload_volume(file: UploadFile = File(...)):
278
+ """Upload a 3D medical volume (NIfTI format).
279
+
280
+ Accepts .nii or .nii.gz files.
281
+ Returns volume metadata and a volume_id for subsequent operations.
282
+ """
283
+ # Validate file extension
284
+ filename = file.filename or ""
285
+ if not (filename.endswith(".nii") or filename.endswith(".nii.gz")):
286
+ raise HTTPException(
287
+ status_code=400,
288
+ detail="Invalid file format. Please upload a NIfTI file (.nii or .nii.gz)"
289
+ )
290
+
291
+ try:
292
+ # Save uploaded file to temporary location
293
+ suffix = ".nii.gz" if filename.endswith(".nii.gz") else ".nii"
294
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
295
+ content = await file.read()
296
+ tmp.write(content)
297
+ tmp_path = tmp.name
298
+
299
+ # Load volume using existing I/O utilities
300
+ from seg_app.data.io import load_nifti
301
+ volume_data = load_nifti(tmp_path)
302
+
303
+ # Clean up temp file
304
+ Path(tmp_path).unlink(missing_ok=True)
305
+
306
+ # Store in state manager
307
+ volume_id = _state_manager.create_volume(
308
+ array=volume_data.array,
309
+ spacing=volume_data.spacing,
310
+ affine=volume_data.affine,
311
+ )
312
+
313
+ return VolumeUploadResponse(
314
+ volume_id=volume_id,
315
+ shape=volume_data.array.shape,
316
+ spacing=volume_data.spacing,
317
+ )
318
+
319
+ except Exception as e:
320
+ logger.error(f"Failed to load volume: {e}")
321
+ raise HTTPException(
322
+ status_code=400,
323
+ detail=f"Failed to load volume: {str(e)}"
324
+ )
325
+
326
+ # -------------------------------------------------------------------------
327
+ # Segmentation
328
+ # -------------------------------------------------------------------------
329
+
330
+ @app.post("/segment", response_model=SegmentationResponse)
331
+ async def run_segmentation(request: SegmentationRequest):
332
+ """Run segmentation on an uploaded volume.
333
+
334
+ Requires volume_id from a previous upload.
335
+ Optionally accepts prompts for SAM-based models.
336
+ """
337
+ # Validate volume exists
338
+ volume_state = _state_manager.get_volume(request.volume_id)
339
+ if volume_state is None:
340
+ raise HTTPException(
341
+ status_code=404,
342
+ detail=f"Volume not found: {request.volume_id}. Please upload a volume first."
343
+ )
344
+
345
+ # Validate model ID
346
+ valid_models = _get_available_model_ids()
347
+ if request.model_id not in valid_models:
348
+ raise HTTPException(
349
+ status_code=400,
350
+ detail=f"Invalid model_id: {request.model_id}. Valid options: {valid_models}"
351
+ )
352
+
353
+ try:
354
+ # Import orchestrator here to avoid side effects
355
+ from seg_app.inference.orchestrator import run_segmentation as orchestrator_run
356
+
357
+ # Convert prompts
358
+ prompts = _convert_prompts_to_orchestrator(request.prompts)
359
+
360
+ # Determine if this is refinement mode
361
+ existing_mask = None
362
+ if prompts is not None and volume_state.mask is not None:
363
+ existing_mask = volume_state.mask
364
+
365
+ # Run segmentation
366
+ result = orchestrator_run(
367
+ volume=volume_state.array,
368
+ task_name="brain_lesion",
369
+ prompts=prompts,
370
+ existing_mask=existing_mask,
371
+ spacing=volume_state.spacing,
372
+ model_id=request.model_id,
373
+ )
374
+
375
+ # Store mask
376
+ _state_manager.update_mask(
377
+ request.volume_id,
378
+ result.mask,
379
+ result.model_id,
380
+ )
381
+
382
+ return SegmentationResponse(
383
+ volume_id=request.volume_id,
384
+ model_id=result.model_id,
385
+ task_name=result.task_name,
386
+ mask_shape=result.mask.shape,
387
+ metadata=result.metadata,
388
+ )
389
+
390
+ except ValueError as e:
391
+ # User-recoverable errors (e.g., SAM requires prompts)
392
+ raise HTTPException(status_code=400, detail=str(e))
393
+ except Exception as e:
394
+ logger.error(f"Segmentation failed: {e}")
395
+ raise HTTPException(
396
+ status_code=500,
397
+ detail=f"Segmentation failed: {str(e)}"
398
+ )
399
+
400
+ # -------------------------------------------------------------------------
401
+ # Refinement
402
+ # -------------------------------------------------------------------------
403
+
404
+ @app.post("/refine", response_model=SegmentationResponse)
405
+ async def refine_segmentation(request: RefineRequest):
406
+ """Refine an existing segmentation using interactive prompts.
407
+
408
+ Requires a previous segmentation to exist for the volume.
409
+ Uses SAM-Med3D for refinement.
410
+ """
411
+ # Validate volume exists
412
+ volume_state = _state_manager.get_volume(request.volume_id)
413
+ if volume_state is None:
414
+ raise HTTPException(
415
+ status_code=404,
416
+ detail=f"Volume not found: {request.volume_id}. Please upload a volume first."
417
+ )
418
+
419
+ # Validate existing mask
420
+ if volume_state.mask is None:
421
+ raise HTTPException(
422
+ status_code=400,
423
+ detail="No existing segmentation found. Run /segment first before refinement."
424
+ )
425
+
426
+ # Validate prompts
427
+ if request.prompts is None or len(request.prompts.items) == 0:
428
+ raise HTTPException(
429
+ status_code=400,
430
+ detail="Prompts are required for refinement."
431
+ )
432
+
433
+ try:
434
+ from seg_app.inference.orchestrator import run_segmentation as orchestrator_run
435
+
436
+ prompts = _convert_prompts_to_orchestrator(request.prompts)
437
+
438
+ # Run refinement (SAM mode)
439
+ result = orchestrator_run(
440
+ volume=volume_state.array,
441
+ task_name="brain_lesion",
442
+ prompts=prompts,
443
+ existing_mask=volume_state.mask,
444
+ spacing=volume_state.spacing,
445
+ model_id="medical-sam-3d",
446
+ full_reinference=False,
447
+ )
448
+
449
+ # Update mask
450
+ _state_manager.update_mask(
451
+ request.volume_id,
452
+ result.mask,
453
+ result.model_id,
454
+ )
455
+
456
+ return SegmentationResponse(
457
+ volume_id=request.volume_id,
458
+ model_id=result.model_id,
459
+ task_name=result.task_name,
460
+ mask_shape=result.mask.shape,
461
+ status="refined",
462
+ metadata=result.metadata,
463
+ )
464
+
465
+ except ValueError as e:
466
+ raise HTTPException(status_code=400, detail=str(e))
467
+ except Exception as e:
468
+ logger.error(f"Refinement failed: {e}")
469
+ raise HTTPException(
470
+ status_code=500,
471
+ detail=f"Refinement failed: {str(e)}"
472
+ )
473
+
474
+ # -------------------------------------------------------------------------
475
+ # Mask Retrieval
476
+ # -------------------------------------------------------------------------
477
+
478
+ @app.get("/mask/{volume_id}")
479
+ async def get_mask(volume_id: str, format: str = "npy"):
480
+ """Retrieve the segmentation mask for a volume.
481
+
482
+ Args:
483
+ volume_id: ID of the volume
484
+ format: Output format - 'npy' (compressed numpy) or 'nifti'
485
+
486
+ Returns:
487
+ Binary mask data in requested format
488
+ """
489
+ # Validate volume exists
490
+ volume_state = _state_manager.get_volume(volume_id)
491
+ if volume_state is None:
492
+ raise HTTPException(
493
+ status_code=404,
494
+ detail=f"Volume not found: {volume_id}"
495
+ )
496
+
497
+ # Validate mask exists
498
+ if volume_state.mask is None:
499
+ raise HTTPException(
500
+ status_code=404,
501
+ detail=f"No segmentation mask found for volume {volume_id}. Run /segment first."
502
+ )
503
+
504
+ if format == "npy":
505
+ # Return compressed numpy array
506
+ buffer = io.BytesIO()
507
+ np.savez_compressed(buffer, mask=volume_state.mask)
508
+ buffer.seek(0)
509
+ return Response(
510
+ content=buffer.getvalue(),
511
+ media_type="application/octet-stream",
512
+ headers={
513
+ "Content-Disposition": f"attachment; filename=mask_{volume_id}.npz"
514
+ }
515
+ )
516
+
517
+ elif format == "nifti":
518
+ # Return as NIfTI file
519
+ try:
520
+ import nibabel as nib
521
+
522
+ # Create NIfTI image with affine
523
+ affine = volume_state.affine
524
+ if affine is None:
525
+ # Create identity affine with spacing
526
+ affine = np.diag([*volume_state.spacing, 1.0])
527
+
528
+ nifti_img = nib.Nifti1Image(
529
+ volume_state.mask.astype(np.uint8),
530
+ affine=affine
531
+ )
532
+
533
+ buffer = io.BytesIO()
534
+ nib.save(nifti_img, buffer)
535
+ buffer.seek(0)
536
+
537
+ return Response(
538
+ content=buffer.getvalue(),
539
+ media_type="application/octet-stream",
540
+ headers={
541
+ "Content-Disposition": f"attachment; filename=mask_{volume_id}.nii.gz"
542
+ }
543
+ )
544
+ except Exception as e:
545
+ logger.error(f"Failed to create NIfTI: {e}")
546
+ raise HTTPException(
547
+ status_code=500,
548
+ detail=f"Failed to create NIfTI file: {str(e)}"
549
+ )
550
+
551
+ else:
552
+ raise HTTPException(
553
+ status_code=400,
554
+ detail=f"Invalid format: {format}. Supported: 'npy', 'nifti'"
555
+ )
556
+
557
+ # -------------------------------------------------------------------------
558
+ # Volume Management
559
+ # -------------------------------------------------------------------------
560
+
561
+ @app.get("/volumes")
562
+ async def list_volumes():
563
+ """List all uploaded volume IDs."""
564
+ return {"volume_ids": _state_manager.list_volumes()}
565
+
566
+ @app.delete("/volume/{volume_id}")
567
+ async def delete_volume(volume_id: str):
568
+ """Delete an uploaded volume and its mask."""
569
+ if _state_manager.delete_volume(volume_id):
570
+ return {"message": f"Volume {volume_id} deleted"}
571
+ else:
572
+ raise HTTPException(
573
+ status_code=404,
574
+ detail=f"Volume not found: {volume_id}"
575
+ )
576
+
577
+ return app
578
+
579
+
580
+ # Create app instance for uvicorn
581
+ app = create_api_app()
582
+
583
+
584
+ if __name__ == "__main__":
585
+ import uvicorn
586
+ uvicorn.run(app, host="127.0.0.1", port=8000)