maregu2023 commited on
Commit
d2c2086
·
1 Parent(s): cf575e8

feat: Add nnUNet integration, Dice metrics, and split INFO panel with overlap viewer

Browse files

- Add nnUNet v2 wrapper (nnunet_wrapper.py) for brain lesion segmentation
- Register nnUNet and fix unet3d-brain-tumor model in registry
- Add Dice score, IoU, sensitivity, precision metrics computation
- Add GT upload endpoint and metrics API endpoints
- Split INFO panel: left (metrics), right (mask comparison viewer)
- Overlap viewer shows TP (green), FP (red), FN (blue) visualization
- Add scroll-based slice navigation for overlap viewer
- Set 3D U-Net (Baseline) as default model selection
- Reset mask/metrics/overlap when uploading new volume
- Update README with local run instructions

README.md CHANGED
@@ -363,3 +363,176 @@ Key endpoints:
363
  - `POST /refine` – Refine segmentation with additional prompts
364
  - `GET /mask/{volume_id}/data` – Download raw mask data
365
  - `GET /health` – Health check endpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  - `POST /refine` – Refine segmentation with additional prompts
364
  - `GET /mask/{volume_id}/data` – Download raw mask data
365
  - `GET /health` – Health check endpoint
366
+
367
+ ---
368
+
369
+ ## 🧬 Integrating Your Own nnUNet Model
370
+
371
+ This section explains how to integrate a trained nnUNet v2 model into the web application as a baseline model option.
372
+
373
+ ### Prerequisites
374
+
375
+ 1. **Trained nnUNet model**: You need a completed nnUNet v2 training run with:
376
+ - Code location (for reference): e.g., `/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet/`
377
+ - **Checkpoints directory**: e.g., `/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/`
378
+
379
+ 2. **nnUNet v2 installed**: The nnunetv2 package must be installed in your conda environment
380
+
381
+ ### Step 1: Install nnUNet v2
382
+
383
+ First, install nnUNet v2 in your conda environment:
384
+
385
+ ```powershell
386
+ # Activate your environment
387
+ conda activate seg_app
388
+
389
+ # Install nnUNet v2
390
+ pip install nnunetv2
391
+ ```
392
+
393
+ Or uncomment the line in `requirements.txt`:
394
+
395
+ ```pip
396
+ # nnunetv2>=2.2
397
+ ```
398
+
399
+ And run:
400
+ ```powershell
401
+ pip install -r requirements.txt
402
+ ```
403
+
404
+ ### Step 2: Locate Your nnUNet Checkpoint Path
405
+
406
+ nnUNet training creates a specific folder structure. You need the path to the **trainer output folder**:
407
+
408
+ ```
409
+ nnUNet_results/
410
+ └── Dataset###_Name/ # e.g., Dataset001_BrainLesion
411
+ └── nnUNetTrainer__nnUNetPlans__3d_fullres/ # ← THIS IS THE PATH YOU NEED
412
+ ├── plans.json # Training plans
413
+ ├── dataset.json # Dataset configuration
414
+ ├── fold_0/
415
+ │ ├── checkpoint_final.pth # Model weights
416
+ │ └── checkpoint_best.pth # Best validation weights
417
+ ├── fold_1/
418
+ │ └── ...
419
+ └── ...
420
+ ```
421
+
422
+ **Your checkpoint path** should point to the trainer folder, for example:
423
+ ```
424
+ /mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres
425
+ ```
426
+
427
+ ### Step 3: Configure nnUNet in Settings
428
+
429
+ Edit the file `seg_app/config/settings.py` and update the `NNUNET_CONFIG`:
430
+
431
+ ```python
432
+ # Global nnUNet configuration instance
433
+ NNUNET_CONFIG = nnUNetConfig(
434
+ # Set the path to your trained nnUNet model folder
435
+ checkpoint_path="/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres",
436
+
437
+ # Optional: which folds to use for inference
438
+ # "all" = use all available folds (ensemble), or [0] for single fold
439
+ use_folds="all",
440
+
441
+ # Optional: test-time augmentation (mirroring)
442
+ # True = more accurate but slower, False = faster inference
443
+ use_mirroring=True,
444
+
445
+ # Optional: customize the display name in the UI dropdown
446
+ display_name="nnU-Net (Brain Lesion)",
447
+ )
448
+ ```
449
+
450
+ ### Step 4: Verify the Integration
451
+
452
+ After configuration, restart the application and verify nnUNet appears in the model dropdown:
453
+
454
+ ```powershell
455
+ # For Gradio UI
456
+ python app.py
457
+
458
+ # For VTK.js Slicer (FastAPI backend)
459
+ uvicorn seg_app.backend.api:app --reload --port 8000
460
+ ```
461
+
462
+ Open the web interface and check that **"nnU-Net (Brain Lesion)"** appears in the model selection dropdown.
463
+
464
+ ### Step 5: Test Inference
465
+
466
+ 1. Upload a NIfTI volume (`.nii` or `.nii.gz`)
467
+ 2. Select **"nnU-Net (Brain Lesion)"** from the model dropdown
468
+ 3. Click **"Run Segmentation"**
469
+ 4. View the segmentation overlay and volume metrics
470
+
471
+ ### nnUNet Configuration Options
472
+
473
+ | Option | Default | Description |
474
+ |--------|---------|-------------|
475
+ | `checkpoint_path` | `None` | **Required.** Path to nnUNet trainer output folder |
476
+ | `use_folds` | `"all"` | Which folds to use. `"all"` for ensemble, `[0]` for single fold |
477
+ | `use_mirroring` | `True` | Test-time augmentation. Improves accuracy but ~4x slower |
478
+ | `display_name` | `"nnU-Net (Brain Lesion)"` | Name shown in UI dropdown |
479
+
480
+ ### Advanced: Multiple nnUNet Models
481
+
482
+ To add multiple nnUNet models for different tasks, you can modify the registration logic in `seg_app/models/nnunet_wrapper.py`. The `register_nnunet()` function can be extended to register multiple models:
483
+
484
+ ```python
485
+ # Example: Register multiple nnUNet models
486
+ def register_nnunet() -> None:
487
+ from seg_app.inference.model_registry import register_model
488
+
489
+ # Model 1: Brain Lesion
490
+ brain_config = ModelConfig(
491
+ model_id="nnunet-brain-lesion",
492
+ local_path="/path/to/nnUNet_results/DatasetBrain/nnUNetTrainer__nnUNetPlans__3d_fullres",
493
+ device="cuda",
494
+ )
495
+ register_model("nnunet-brain-lesion", nnUNetWrapper, brain_config)
496
+
497
+ # Model 2: Liver (example)
498
+ liver_config = ModelConfig(
499
+ model_id="nnunet-liver",
500
+ local_path="/path/to/nnUNet_results/DatasetLiver/nnUNetTrainer__nnUNetPlans__3d_fullres",
501
+ device="cuda",
502
+ )
503
+ register_model("nnunet-liver", nnUNetWrapper, liver_config)
504
+ ```
505
+
506
+ ### Troubleshooting nnUNet Integration
507
+
508
+ **"nnunetv2 is not installed"**
509
+ ```powershell
510
+ pip install nnunetv2
511
+ ```
512
+
513
+ **"plans.json not found"**
514
+ - Ensure `checkpoint_path` points to the trainer folder (not the dataset folder)
515
+ - The path should contain `plans.json` or `nnUNetPlans.json`
516
+
517
+ **"checkpoint_final.pth not found"**
518
+ - Verify training completed successfully
519
+ - Check if only `checkpoint_best.pth` exists (modify the wrapper to use it)
520
+
521
+ **"Model not appearing in dropdown"**
522
+ - Check that `checkpoint_path` is not `None` in `NNUNET_CONFIG`
523
+ - Look for warnings in the terminal when starting the app
524
+
525
+ **Out of memory during inference**
526
+ - Reduce `use_mirroring` to `False` (reduces memory by ~4x)
527
+ - Use fewer folds: `use_folds=[0]` instead of `"all"`
528
+ - Reduce input volume size
529
+
530
+ ### Files Modified for nnUNet Integration
531
+
532
+ | File | Purpose |
533
+ |------|---------|
534
+ | [seg_app/models/nnunet_wrapper.py](seg_app/models/nnunet_wrapper.py) | nnUNet model wrapper implementing BaseModel interface |
535
+ | [seg_app/config/settings.py](seg_app/config/settings.py) | nnUNet configuration (checkpoint path, inference settings) |
536
+ | [seg_app/inference/model_registry.py](seg_app/inference/model_registry.py) | Registers nnUNet model during lazy initialization |
537
+ | [seg_app/inference/orchestrator.py](seg_app/inference/orchestrator.py) | Adds nnUNet to available models list |
538
+ | [requirements.txt](requirements.txt) | Optional nnunetv2 dependency |
app.py CHANGED
@@ -131,3 +131,27 @@ if __name__ == "__main__":
131
  server_port=7869,
132
  share=False,
133
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  server_port=7869,
132
  share=False,
133
  )
134
+
135
+
136
+ """
137
+
138
+ >> First Launch API server backend: Start the FastAPI backend with uvicorn
139
+ # uvicorn seg_app.api.app:app --reload --host 127.0.0.1 --port 8000
140
+ uvicorn seg_app.backend.api:app --reload --host 127.0.0.1 --port 8000
141
+
142
+ >> Second, Start the FastAPI Frontend:
143
+ cd seg_app/ui_slicer
144
+ python serve_frontend.py
145
+
146
+
147
+ >> Launch Gradio App for seg_app.
148
+ python app.py
149
+
150
+ # Terminal 1: FastAPI Backend
151
+ uvicorn seg_app.backend.api:app --reload --host 127.0.0.1 --port 8000
152
+
153
+ # Terminal 2: Frontend Server
154
+ cd seg_app/ui_slicer
155
+ python serve_frontend.py
156
+
157
+ """
requirements.txt CHANGED
@@ -42,3 +42,9 @@ huggingface_hub>=0.20.0,<0.28.0
42
  # =============================================================================
43
  einops>=0.6.0
44
  timm>=0.9.0
 
 
 
 
 
 
 
42
  # =============================================================================
43
  einops>=0.6.0
44
  timm>=0.9.0
45
+
46
+ # =============================================================================
47
+ # nnUNet v2 (optional - for local nnUNet models)
48
+ # =============================================================================
49
+ # Uncomment to enable nnUNet support:
50
+ # nnunetv2>=2.2
seg_app/backend/api.py CHANGED
@@ -108,6 +108,30 @@ class AvailableModelsResponse(BaseModel):
108
  models: List[Dict[str, str]]
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # =============================================================================
112
  # In-Memory State Storage
113
  # =============================================================================
@@ -120,6 +144,7 @@ class VolumeState:
120
  spacing: Tuple[float, float, float]
121
  affine: Optional[np.ndarray] = None
122
  mask: Optional[np.ndarray] = None
 
123
  last_model_id: Optional[str] = None
124
  metadata: Dict[str, Any] = field(default_factory=dict)
125
 
@@ -175,6 +200,17 @@ class StateManager:
175
  return True
176
  return False
177
 
 
 
 
 
 
 
 
 
 
 
 
178
  def list_volumes(self) -> List[str]:
179
  """List all stored volume IDs."""
180
  return list(self._volumes.keys())
@@ -692,6 +728,181 @@ def create_api_app() -> FastAPI:
692
  }
693
  )
694
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
695
  return app
696
 
697
 
@@ -704,13 +915,11 @@ if __name__ == "__main__":
704
  uvicorn.run(app, host="127.0.0.1", port=8000)
705
 
706
 
707
- """
708
- Terminal 1 - Start Backend FIRST:
709
- cd e:\academia\research_projects\web_app
710
- conda activate seg_app
711
- uvicorn seg_app.backend.api:app --reload --port 8000
712
-
713
- Terminal 2 - Start Frontend SECOND:
714
- cd e:\academia\research_projects\web_app\seg_app\ui_slicer
715
- python serve_frontend.py
716
- """
 
108
  models: List[Dict[str, str]]
109
 
110
 
111
+ class MetricsResponse(BaseModel):
112
+ """Segmentation metrics including Dice score."""
113
+ volume_id: str
114
+ has_ground_truth: bool
115
+ voxel_count: int
116
+ volume_mm3: float
117
+ volume_ml: float
118
+ # Dice metrics (only if ground truth is available)
119
+ dice_score: Optional[float] = None
120
+ iou_score: Optional[float] = None
121
+ sensitivity: Optional[float] = None
122
+ precision: Optional[float] = None
123
+ true_positives: Optional[int] = None
124
+ false_positives: Optional[int] = None
125
+ false_negatives: Optional[int] = None
126
+
127
+
128
+ class GroundTruthUploadResponse(BaseModel):
129
+ """Response after uploading ground truth mask."""
130
+ volume_id: str
131
+ gt_shape: Tuple[int, int, int]
132
+ message: str = "Ground truth uploaded successfully"
133
+
134
+
135
  # =============================================================================
136
  # In-Memory State Storage
137
  # =============================================================================
 
144
  spacing: Tuple[float, float, float]
145
  affine: Optional[np.ndarray] = None
146
  mask: Optional[np.ndarray] = None
147
+ ground_truth: Optional[np.ndarray] = None # For Dice score computation
148
  last_model_id: Optional[str] = None
149
  metadata: Dict[str, Any] = field(default_factory=dict)
150
 
 
200
  return True
201
  return False
202
 
203
+ def update_ground_truth(
204
+ self,
205
+ volume_id: str,
206
+ ground_truth: np.ndarray,
207
+ ) -> None:
208
+ """Update the ground truth mask for a volume."""
209
+ if volume_id not in self._volumes:
210
+ raise KeyError(f"Volume not found: {volume_id}")
211
+ self._volumes[volume_id].ground_truth = ground_truth
212
+ logger.info(f"Updated ground truth for volume {volume_id}")
213
+
214
  def list_volumes(self) -> List[str]:
215
  """List all stored volume IDs."""
216
  return list(self._volumes.keys())
 
728
  }
729
  )
730
 
731
+ # -------------------------------------------------------------------------
732
+ # Ground Truth Upload (for Dice score computation)
733
+ # -------------------------------------------------------------------------
734
+
735
+ @app.post("/ground-truth/{volume_id}", response_model=GroundTruthUploadResponse)
736
+ async def upload_ground_truth(volume_id: str, file: UploadFile = File(...)):
737
+ """Upload a ground truth segmentation mask for Dice score computation.
738
+
739
+ The ground truth mask must match the shape of the uploaded volume.
740
+ Accepts .nii or .nii.gz files.
741
+ """
742
+ # Validate volume exists
743
+ volume_state = _state_manager.get_volume(volume_id)
744
+ if volume_state is None:
745
+ raise HTTPException(
746
+ status_code=404,
747
+ detail=f"Volume not found: {volume_id}. Please upload volume first."
748
+ )
749
+
750
+ # Validate file extension
751
+ filename = file.filename or ""
752
+ if not (filename.endswith(".nii") or filename.endswith(".nii.gz")):
753
+ raise HTTPException(
754
+ status_code=400,
755
+ detail="Invalid file format. Please upload a NIfTI file (.nii or .nii.gz)"
756
+ )
757
+
758
+ try:
759
+ # Save to temp file
760
+ suffix = ".nii.gz" if filename.endswith(".nii.gz") else ".nii"
761
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
762
+ content = await file.read()
763
+ tmp.write(content)
764
+ tmp_path = tmp.name
765
+
766
+ # Load with nibabel
767
+ import nibabel as nib
768
+ img = nib.load(tmp_path)
769
+ gt_array = np.asarray(img.dataobj).copy()
770
+ del img
771
+
772
+ # Clean up temp file
773
+ try:
774
+ Path(tmp_path).unlink(missing_ok=True)
775
+ except PermissionError:
776
+ logger.warning(f"Could not delete temp file: {tmp_path}")
777
+
778
+ # Handle 4D with trivial dimension
779
+ if gt_array.ndim == 4 and gt_array.shape[-1] == 1:
780
+ gt_array = gt_array[..., 0]
781
+
782
+ # Validate shape matches volume
783
+ if gt_array.shape != volume_state.array.shape:
784
+ raise HTTPException(
785
+ status_code=400,
786
+ detail=f"Shape mismatch: ground truth {gt_array.shape} vs volume {volume_state.array.shape}"
787
+ )
788
+
789
+ # Binarize (any non-zero is foreground)
790
+ gt_array = (gt_array > 0).astype(np.uint8)
791
+
792
+ # Store ground truth
793
+ _state_manager.update_ground_truth(volume_id, gt_array)
794
+
795
+ return GroundTruthUploadResponse(
796
+ volume_id=volume_id,
797
+ gt_shape=gt_array.shape,
798
+ )
799
+
800
+ except HTTPException:
801
+ raise
802
+ except Exception as e:
803
+ logger.error(f"Failed to load ground truth: {e}")
804
+ raise HTTPException(
805
+ status_code=400,
806
+ detail=f"Failed to load ground truth: {str(e)}"
807
+ )
808
+
809
+ # -------------------------------------------------------------------------
810
+ # Ground Truth Mask Data Endpoint (raw bytes for overlap viewer)
811
+ # -------------------------------------------------------------------------
812
+
813
+ @app.get("/ground-truth/{volume_id}/data")
814
+ async def get_ground_truth_data(volume_id: str):
815
+ """Get raw ground truth mask data as Uint8 array.
816
+
817
+ Returns the ground truth data as a raw binary Uint8Array,
818
+ suitable for direct use in overlap visualization.
819
+ """
820
+ volume_state = _state_manager.get_volume(volume_id)
821
+ if volume_state is None:
822
+ raise HTTPException(
823
+ status_code=404,
824
+ detail=f"Volume not found: {volume_id}"
825
+ )
826
+
827
+ if volume_state.ground_truth is None:
828
+ raise HTTPException(
829
+ status_code=404,
830
+ detail=f"No ground truth found for volume {volume_id}"
831
+ )
832
+
833
+ # Convert to uint8 and return as raw bytes
834
+ data = volume_state.ground_truth.astype(np.uint8)
835
+
836
+ return Response(
837
+ content=data.tobytes(),
838
+ media_type="application/octet-stream",
839
+ headers={
840
+ "Content-Disposition": f"attachment; filename=gt_{volume_id}.raw",
841
+ "X-GT-Shape": ",".join(map(str, volume_state.ground_truth.shape)),
842
+ }
843
+ )
844
+
845
+ # -------------------------------------------------------------------------
846
+ # Metrics Endpoint (including Dice score)
847
+ # -------------------------------------------------------------------------
848
+
849
+ @app.get("/metrics/{volume_id}", response_model=MetricsResponse)
850
+ async def get_metrics(volume_id: str):
851
+ """Get segmentation metrics including Dice score if ground truth is available.
852
+
853
+ Returns volume, voxel count, and overlap metrics (Dice, IoU, etc.)
854
+ if a ground truth mask has been uploaded.
855
+ """
856
+ volume_state = _state_manager.get_volume(volume_id)
857
+ if volume_state is None:
858
+ raise HTTPException(
859
+ status_code=404,
860
+ detail=f"Volume not found: {volume_id}"
861
+ )
862
+
863
+ if volume_state.mask is None:
864
+ raise HTTPException(
865
+ status_code=400,
866
+ detail="No segmentation mask found. Run /segment first."
867
+ )
868
+
869
+ # Basic metrics
870
+ from seg_app.metrics.segmentation_metrics import (
871
+ compute_segmentation_metrics,
872
+ compute_metrics_with_ground_truth,
873
+ )
874
+
875
+ basic_metrics = compute_segmentation_metrics(
876
+ volume_state.mask, volume_state.spacing
877
+ )
878
+
879
+ voxel_count = sum(basic_metrics["voxel_counts"].values())
880
+
881
+ response = MetricsResponse(
882
+ volume_id=volume_id,
883
+ has_ground_truth=volume_state.ground_truth is not None,
884
+ voxel_count=voxel_count,
885
+ volume_mm3=basic_metrics["total_volume_mm3"],
886
+ volume_ml=basic_metrics["total_volume_ml"],
887
+ )
888
+
889
+ # Add Dice metrics if ground truth is available
890
+ if volume_state.ground_truth is not None:
891
+ gt_metrics = compute_metrics_with_ground_truth(
892
+ volume_state.mask,
893
+ volume_state.ground_truth,
894
+ volume_state.spacing,
895
+ )
896
+ response.dice_score = gt_metrics["dice_score"]
897
+ response.iou_score = gt_metrics["iou_score"]
898
+ response.sensitivity = gt_metrics["sensitivity"]
899
+ response.precision = gt_metrics["precision"]
900
+ response.true_positives = gt_metrics["true_positives"]
901
+ response.false_positives = gt_metrics["false_positives"]
902
+ response.false_negatives = gt_metrics["false_negatives"]
903
+
904
+ return response
905
+
906
  return app
907
 
908
 
 
915
  uvicorn.run(app, host="127.0.0.1", port=8000)
916
 
917
 
918
+ # Terminal 1 - Start Backend FIRST:
919
+ # cd e:/academia/research_projects/web_app
920
+ # conda activate seg_app
921
+ # uvicorn seg_app.backend.api:app --reload --port 8000
922
+ #
923
+ # Terminal 2 - Start Frontend SECOND:
924
+ # cd e:/academia/research_projects/web_app/seg_app/ui_slicer
925
+ # python serve_frontend.py
 
 
seg_app/config/settings.py CHANGED
@@ -76,6 +76,47 @@ class HFHubConfig:
76
  HF_HUB_CONFIG = HFHubConfig()
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  # =============================================================================
80
  # Application Settings
81
  # =============================================================================
 
76
  HF_HUB_CONFIG = HFHubConfig()
77
 
78
 
79
+ # =============================================================================
80
+ # nnUNet Configuration
81
+ # =============================================================================
82
+
83
+ @dataclass
84
+ class nnUNetConfig:
85
+ """Configuration for nnUNet models.
86
+
87
+ Set checkpoint_path to enable nnUNet in the model dropdown.
88
+
89
+ Attributes:
90
+ checkpoint_path: Path to the nnUNet training output folder.
91
+ This should point to the trainer directory, e.g.:
92
+ /path/to/nnUNet_results/DatasetXXX_Name/nnUNetTrainer__nnUNetPlans__3d_fullres
93
+ use_folds: Which folds to use for inference.
94
+ Options: "all" (default), or list of integers [0], [0, 1], etc.
95
+ use_mirroring: Whether to use test-time augmentation with mirroring.
96
+ Improves accuracy but increases inference time.
97
+ display_name: Name shown in the UI model dropdown.
98
+ """
99
+ # Path to nnUNet checkpoint folder
100
+ # Set this to your local nnUNet_results path to enable nnUNet
101
+ # Example: "/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres"
102
+ checkpoint_path: Optional[str] = None
103
+
104
+ # Inference settings
105
+ use_folds: str = "all" # "all" or specific fold like [0]
106
+ use_mirroring: bool = True # Test-time augmentation (slower but more accurate)
107
+
108
+ # Display settings
109
+ display_name: str = "nnU-Net (Brain Lesion)"
110
+
111
+
112
+ # Global nnUNet configuration instance
113
+ # Users should modify this path to point to their trained nnUNet model
114
+ NNUNET_CONFIG = nnUNetConfig(
115
+ # Uncomment and set this path to enable nnUNet:
116
+ # checkpoint_path="/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset###_Name/nnUNetTrainer__nnUNetPlans__3d_fullres",
117
+ )
118
+
119
+
120
  # =============================================================================
121
  # Application Settings
122
  # =============================================================================
seg_app/inference/model_registry.py CHANGED
@@ -232,4 +232,12 @@ def _ensure_models_registered() -> None:
232
  from seg_app.models.medical_sam_3d import register_medical_sam_3d
233
  register_medical_sam_3d()
234
 
 
 
 
 
 
 
 
 
235
  _models_registered = True
 
232
  from seg_app.models.medical_sam_3d import register_medical_sam_3d
233
  register_medical_sam_3d()
234
 
235
+ # Register nnUNet if configured (local checkpoint required)
236
+ try:
237
+ from seg_app.models.nnunet_wrapper import register_nnunet
238
+ register_nnunet()
239
+ except Exception as e:
240
+ import logging
241
+ logging.getLogger(__name__).debug(f"nnUNet registration skipped: {e}")
242
+
243
  _models_registered = True
seg_app/inference/orchestrator.py CHANGED
@@ -190,12 +190,24 @@ def get_available_models() -> List[Dict[str, str]]:
190
  """Get list of available models for brain lesion segmentation.
191
 
192
  Returns:
193
- List of dicts with 'id' and 'display_name' for each model
 
194
  """
195
- return [
 
 
196
  {"id": "medical-sam-3d", "display_name": "Medical SAM 3D (SA-Med3D-140K)"},
197
  {"id": "unet3d-brain-tumor", "display_name": "3D U-Net (Baseline)"},
198
  ]
 
 
 
 
 
 
 
 
 
199
 
200
 
201
  def get_task_info(task_name: str) -> Dict[str, Any]:
 
190
  """Get list of available models for brain lesion segmentation.
191
 
192
  Returns:
193
+ List of dicts with 'id' and 'display_name' for each model.
194
+ nnUNet is included only if configured in settings.
195
  """
196
+ from seg_app.config.settings import NNUNET_CONFIG
197
+
198
+ models = [
199
  {"id": "medical-sam-3d", "display_name": "Medical SAM 3D (SA-Med3D-140K)"},
200
  {"id": "unet3d-brain-tumor", "display_name": "3D U-Net (Baseline)"},
201
  ]
202
+
203
+ # Add nnUNet if checkpoint path is configured
204
+ if NNUNET_CONFIG.checkpoint_path is not None:
205
+ models.insert(0, {
206
+ "id": "nnunet-brain-lesion",
207
+ "display_name": NNUNET_CONFIG.display_name,
208
+ })
209
+
210
+ return models
211
 
212
 
213
  def get_task_info(task_name: str) -> Dict[str, Any]:
seg_app/metrics/segmentation_metrics.py CHANGED
@@ -102,6 +102,144 @@ def compute_segmentation_metrics(
102
  }
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def format_metrics_for_display(metrics: Dict[str, Any]) -> Dict[str, Any]:
106
  """Format metrics dictionary for UI display.
107
 
 
102
  }
103
 
104
 
105
+ def compute_dice_score(
106
+ prediction: np.ndarray,
107
+ ground_truth: np.ndarray,
108
+ smooth: float = 1e-6,
109
+ ) -> float:
110
+ """Compute the Dice similarity coefficient between prediction and ground truth.
111
+
112
+ Dice = 2 * |P ∩ G| / (|P| + |G|)
113
+
114
+ Args:
115
+ prediction: Predicted binary segmentation mask (D, H, W)
116
+ ground_truth: Ground truth binary segmentation mask (D, H, W)
117
+ smooth: Smoothing factor to avoid division by zero
118
+
119
+ Returns:
120
+ Dice score in range [0, 1], where 1 is perfect overlap
121
+
122
+ Raises:
123
+ ValueError: If masks have different shapes
124
+ """
125
+ if prediction.shape != ground_truth.shape:
126
+ raise ValueError(
127
+ f"Shape mismatch: prediction {prediction.shape} vs ground_truth {ground_truth.shape}"
128
+ )
129
+
130
+ # Binarize masks
131
+ pred_binary = (prediction > 0).astype(np.float32)
132
+ gt_binary = (ground_truth > 0).astype(np.float32)
133
+
134
+ # Compute intersection and unions
135
+ intersection = np.sum(pred_binary * gt_binary)
136
+ pred_sum = np.sum(pred_binary)
137
+ gt_sum = np.sum(gt_binary)
138
+
139
+ # Dice formula with smoothing
140
+ dice = (2.0 * intersection + smooth) / (pred_sum + gt_sum + smooth)
141
+
142
+ return float(dice)
143
+
144
+
145
+ def compute_iou_score(
146
+ prediction: np.ndarray,
147
+ ground_truth: np.ndarray,
148
+ smooth: float = 1e-6,
149
+ ) -> float:
150
+ """Compute the Intersection over Union (Jaccard index) between masks.
151
+
152
+ IoU = |P ∩ G| / |P ∪ G|
153
+
154
+ Args:
155
+ prediction: Predicted binary segmentation mask (D, H, W)
156
+ ground_truth: Ground truth binary segmentation mask (D, H, W)
157
+ smooth: Smoothing factor to avoid division by zero
158
+
159
+ Returns:
160
+ IoU score in range [0, 1], where 1 is perfect overlap
161
+ """
162
+ if prediction.shape != ground_truth.shape:
163
+ raise ValueError(
164
+ f"Shape mismatch: prediction {prediction.shape} vs ground_truth {ground_truth.shape}"
165
+ )
166
+
167
+ pred_binary = (prediction > 0).astype(np.float32)
168
+ gt_binary = (ground_truth > 0).astype(np.float32)
169
+
170
+ intersection = np.sum(pred_binary * gt_binary)
171
+ union = np.sum(pred_binary) + np.sum(gt_binary) - intersection
172
+
173
+ iou = (intersection + smooth) / (union + smooth)
174
+
175
+ return float(iou)
176
+
177
+
178
+ def compute_metrics_with_ground_truth(
179
+ prediction: np.ndarray,
180
+ ground_truth: np.ndarray,
181
+ spacing: Tuple[float, float, float],
182
+ ) -> Dict[str, Any]:
183
+ """Compute comprehensive metrics comparing prediction to ground truth.
184
+
185
+ Args:
186
+ prediction: Predicted segmentation mask (D, H, W)
187
+ ground_truth: Ground truth segmentation mask (D, H, W)
188
+ spacing: Voxel spacing in mm as (depth, height, width)
189
+
190
+ Returns:
191
+ Dict containing all metrics including Dice, IoU, volumes, etc.
192
+ """
193
+ # Basic metrics for prediction
194
+ pred_metrics = compute_segmentation_metrics(prediction, spacing)
195
+ gt_metrics = compute_segmentation_metrics(ground_truth, spacing)
196
+
197
+ # Overlap metrics
198
+ dice = compute_dice_score(prediction, ground_truth)
199
+ iou = compute_iou_score(prediction, ground_truth)
200
+
201
+ # Additional statistics
202
+ pred_binary = prediction > 0
203
+ gt_binary = ground_truth > 0
204
+
205
+ tp = np.sum(pred_binary & gt_binary) # True positives
206
+ fp = np.sum(pred_binary & ~gt_binary) # False positives
207
+ fn = np.sum(~pred_binary & gt_binary) # False negatives
208
+ tn = np.sum(~pred_binary & ~gt_binary) # True negatives
209
+
210
+ # Derived metrics
211
+ sensitivity = tp / (tp + fn + 1e-6) # Recall / TPR
212
+ specificity = tn / (tn + fp + 1e-6) # TNR
213
+ precision = tp / (tp + fp + 1e-6) # PPV
214
+
215
+ return {
216
+ # Overlap metrics
217
+ "dice_score": round(dice, 4),
218
+ "iou_score": round(iou, 4),
219
+
220
+ # Classification metrics
221
+ "sensitivity": round(sensitivity, 4),
222
+ "specificity": round(specificity, 4),
223
+ "precision": round(precision, 4),
224
+
225
+ # Voxel counts
226
+ "true_positives": int(tp),
227
+ "false_positives": int(fp),
228
+ "false_negatives": int(fn),
229
+
230
+ # Volume metrics
231
+ "prediction_volume_mm3": pred_metrics["total_volume_mm3"],
232
+ "ground_truth_volume_mm3": gt_metrics["total_volume_mm3"],
233
+ "volume_difference_mm3": round(
234
+ pred_metrics["total_volume_mm3"] - gt_metrics["total_volume_mm3"], 2
235
+ ),
236
+
237
+ # Metadata
238
+ "mask_shape": prediction.shape,
239
+ "spacing_mm": spacing,
240
+ }
241
+
242
+
243
  def format_metrics_for_display(metrics: Dict[str, Any]) -> Dict[str, Any]:
244
  """Format metrics dictionary for UI display.
245
 
seg_app/models/monai_autoseg.py CHANGED
@@ -344,6 +344,25 @@ def register_monai_models() -> None:
344
  ),
345
  )
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  # Additional models can be registered here
348
  # register_model(
349
  # model_id="monai-auto3dseg-spleen",
 
344
  ),
345
  )
346
 
347
+ # 3D U-Net for brain lesion segmentation (baseline model)
348
+ register_model(
349
+ model_id="unet3d-brain-tumor",
350
+ model_class=MONAIAuto3DSeg,
351
+ config=ModelConfig(
352
+ model_id="unet3d-brain-tumor",
353
+ hf_hub_path="your-org/brain-tumor-model", # Placeholder HF path
354
+ device="cuda" if torch.cuda.is_available() else "cpu",
355
+ preprocessing={
356
+ "target_spacing": (1.0, 1.0, 1.0),
357
+ "normalize": True,
358
+ },
359
+ postprocessing={
360
+ "threshold": 0.5,
361
+ "min_component_size": 50,
362
+ },
363
+ ),
364
+ )
365
+
366
  # Additional models can be registered here
367
  # register_model(
368
  # model_id="monai-auto3dseg-spleen",
seg_app/models/nnunet_wrapper.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ nnUNet model wrapper for brain lesion segmentation.
3
+
4
+ Provides a standardized BaseModel interface for nnUNet v2 models,
5
+ supporting loading from local checkpoint directories.
6
+
7
+ nnUNet v2 Reference:
8
+ https://github.com/MIC-DKFZ/nnUNet
9
+
10
+ Expected directory structure:
11
+ nnUNet_results/
12
+ └── DatasetXXX_Name/
13
+ └── nnUNetTrainer__nnUNetPlans__3d_fullres/
14
+ ├── plans.json
15
+ ├── dataset.json
16
+ ├── fold_0/
17
+ │ └── checkpoint_final.pth
18
+ ├── fold_1/
19
+ │ └── checkpoint_final.pth
20
+ └── ...
21
+
22
+ Usage:
23
+ config = ModelConfig(
24
+ model_id="nnunet-brain-lesion",
25
+ local_path="/path/to/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres",
26
+ )
27
+ model = nnUNetWrapper(config)
28
+ model.load()
29
+ mask = model(volume, spacing=(1.0, 1.0, 1.0))
30
+ """
31
+
32
+ import logging
33
+ import os
34
+ import tempfile
35
+ from pathlib import Path
36
+ from typing import Any, Dict, List, Optional, Tuple, Union
37
+
38
+ import numpy as np
39
+ import torch
40
+
41
+ from seg_app.models.base import BaseModel, ModelConfig
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ class nnUNetWrapper(BaseModel):
47
+ """nnUNet v2 model wrapper implementing the BaseModel interface.
48
+
49
+ This wrapper integrates nnUNet v2's inference API with the seg_app
50
+ orchestrator pattern. It supports:
51
+
52
+ - Loading from local checkpoint directory
53
+ - Single-fold or ensemble inference
54
+ - Automatic preprocessing (handled by nnUNet predictor)
55
+ - Configurable postprocessing
56
+
57
+ Note:
58
+ nnUNet handles its own preprocessing (resampling, normalization)
59
+ based on the plans.json, so preprocess() passes through the volume
60
+ and postprocess() handles the output conversion.
61
+
62
+ Attributes:
63
+ predictor: The nnUNet predictor instance (lazy-loaded)
64
+ use_folds: Which folds to use for inference (default: all available)
65
+ use_gaussian: Whether to use Gaussian weighting for patch aggregation
66
+ use_mirroring: Whether to use test-time augmentation with mirroring
67
+ """
68
+
69
+ def __init__(self, config: ModelConfig):
70
+ """Initialize nnUNet wrapper.
71
+
72
+ Args:
73
+ config: Model configuration with local_path pointing to the
74
+ trained nnUNet model folder (e.g., nnUNetTrainer__nnUNetPlans__3d_fullres)
75
+ """
76
+ super().__init__(config)
77
+
78
+ self.predictor = None
79
+
80
+ # Inference settings (can be overridden via config.preprocessing)
81
+ preprocessing = config.preprocessing or {}
82
+ self.use_folds: Union[str, List[int]] = preprocessing.get("use_folds", "all")
83
+ self.use_gaussian: bool = preprocessing.get("use_gaussian", True)
84
+ self.use_mirroring: bool = preprocessing.get("use_mirroring", True)
85
+ self.save_probabilities: bool = preprocessing.get("save_probabilities", False)
86
+
87
+ # Postprocessing settings
88
+ postprocessing = config.postprocessing or {}
89
+ self.threshold: float = postprocessing.get("threshold", 0.5)
90
+
91
+ # Cache for input metadata (needed if nnUNet changes spacing)
92
+ self._original_spacing: Optional[Tuple[float, float, float]] = None
93
+ self._original_shape: Optional[Tuple[int, int, int]] = None
94
+
95
+ def load(self) -> None:
96
+ """Load nnUNet model from local checkpoint directory.
97
+
98
+ This initializes the nnUNetPredictor and loads model weights.
99
+ The predictor handles all internal setup (network architecture,
100
+ preprocessing pipeline, etc.) based on plans.json.
101
+
102
+ Raises:
103
+ FileNotFoundError: If local_path doesn't exist
104
+ ImportError: If nnunetv2 is not installed
105
+ RuntimeError: If model loading fails
106
+ """
107
+ if self._is_loaded:
108
+ logger.info(f"Model {self.config.model_id} already loaded")
109
+ return
110
+
111
+ if self.config.local_path is None:
112
+ raise RuntimeError(
113
+ f"nnUNet requires a local path to the trained model folder. "
114
+ f"Set local_path in ModelConfig to the nnUNetTrainer__nnUNetPlans__3d_fullres directory."
115
+ )
116
+
117
+ model_folder = Path(self.config.local_path)
118
+ if not model_folder.exists():
119
+ raise FileNotFoundError(
120
+ f"nnUNet model folder not found: {model_folder}\n"
121
+ f"Expected structure: nnUNet_results/DatasetXXX_Name/nnUNetTrainer__nnUNetPlans__3d_fullres/"
122
+ )
123
+
124
+ # Check for required files
125
+ plans_file = model_folder / "plans.json"
126
+ if not plans_file.exists():
127
+ # Try alternative location
128
+ plans_file = model_folder / "nnUNetPlans.json"
129
+ if not plans_file.exists():
130
+ raise FileNotFoundError(
131
+ f"plans.json not found in {model_folder}. "
132
+ f"Ensure this is a valid nnUNet training output folder."
133
+ )
134
+
135
+ try:
136
+ from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
137
+ except ImportError:
138
+ raise ImportError(
139
+ "nnunetv2 is not installed. Install it with:\n"
140
+ "pip install nnunetv2\n"
141
+ "Or see: https://github.com/MIC-DKFZ/nnUNet"
142
+ )
143
+
144
+ logger.info(f"Loading nnUNet model from: {model_folder}")
145
+
146
+ # Initialize predictor
147
+ self.predictor = nnUNetPredictor(
148
+ tile_step_size=0.5,
149
+ use_gaussian=self.use_gaussian,
150
+ use_mirroring=self.use_mirroring,
151
+ perform_everything_on_device=True,
152
+ device=torch.device(self.device),
153
+ verbose=False,
154
+ verbose_preprocessing=False,
155
+ allow_tqdm=False,
156
+ )
157
+
158
+ # Determine which folds to use
159
+ if self.use_folds == "all":
160
+ folds = self._find_available_folds(model_folder)
161
+ else:
162
+ folds = self.use_folds
163
+
164
+ if not folds:
165
+ # Default to fold 0 if no specific folds found
166
+ folds = [0]
167
+ logger.warning(f"No fold directories found, defaulting to fold 0")
168
+
169
+ logger.info(f"Using folds: {folds}")
170
+
171
+ # Initialize from trained model folder
172
+ # This loads plans.json, network architecture, and weights
173
+ self.predictor.initialize_from_trained_model_folder(
174
+ model_training_output_dir=str(model_folder),
175
+ use_folds=folds,
176
+ checkpoint_name="checkpoint_final.pth", # or "checkpoint_best.pth"
177
+ )
178
+
179
+ self._is_loaded = True
180
+ logger.info(f"nnUNet model loaded successfully on {self.device}")
181
+
182
+ def _find_available_folds(self, model_folder: Path) -> List[int]:
183
+ """Find available fold directories in the model folder.
184
+
185
+ Args:
186
+ model_folder: Path to nnUNet training output
187
+
188
+ Returns:
189
+ List of fold numbers (e.g., [0, 1, 2, 3, 4])
190
+ """
191
+ folds = []
192
+ for item in model_folder.iterdir():
193
+ if item.is_dir() and item.name.startswith("fold_"):
194
+ try:
195
+ fold_num = int(item.name.split("_")[1])
196
+ # Check if checkpoint exists
197
+ if (item / "checkpoint_final.pth").exists():
198
+ folds.append(fold_num)
199
+ elif (item / "checkpoint_best.pth").exists():
200
+ folds.append(fold_num)
201
+ except (ValueError, IndexError):
202
+ continue
203
+ return sorted(folds)
204
+
205
+ def preprocess(
206
+ self,
207
+ volume: np.ndarray,
208
+ spacing: Optional[Tuple[float, float, float]] = None,
209
+ ) -> torch.Tensor:
210
+ """Prepare volume for nnUNet inference.
211
+
212
+ Note: nnUNet's predictor handles its own preprocessing internally,
213
+ so this method mainly stores metadata and prepares the volume format.
214
+
215
+ Args:
216
+ volume: 3D numpy array with shape (D, H, W)
217
+ spacing: Voxel spacing in mm as (depth, height, width)
218
+
219
+ Returns:
220
+ Volume as tensor (nnUNet expects numpy, but we follow interface)
221
+ """
222
+ # Store original metadata for postprocessing
223
+ self._original_shape = volume.shape
224
+ self._original_spacing = spacing
225
+
226
+ # nnUNet expects (C, D, H, W) where C is modality/channel
227
+ # For single-modality MRI, add channel dimension
228
+ if volume.ndim == 3:
229
+ volume = volume[np.newaxis, ...] # (1, D, H, W)
230
+
231
+ # Return as tensor to match interface (but nnUNet will work with numpy)
232
+ tensor = torch.from_numpy(volume.astype(np.float32))
233
+ return tensor.to(self.device)
234
+
235
+ def predict(
236
+ self,
237
+ tensor: torch.Tensor,
238
+ prompts: Optional[Any] = None,
239
+ ) -> torch.Tensor:
240
+ """Run nnUNet inference.
241
+
242
+ Args:
243
+ tensor: Preprocessed volume tensor with shape (1, D, H, W)
244
+ prompts: Not used for nnUNet (ignored with warning)
245
+
246
+ Returns:
247
+ Segmentation output tensor
248
+
249
+ Raises:
250
+ RuntimeError: If model not loaded
251
+ """
252
+ if not self._is_loaded or self.predictor is None:
253
+ raise RuntimeError("Model not loaded. Call load() first.")
254
+
255
+ if prompts is not None:
256
+ logger.warning(
257
+ "nnUNet does not support interactive prompts. Ignoring provided prompts. "
258
+ "Use Medical SAM 3D for prompt-based refinement."
259
+ )
260
+
261
+ # Convert back to numpy for nnUNet (it works with numpy internally)
262
+ volume_np = tensor.cpu().numpy() # (1, D, H, W)
263
+
264
+ # Prepare properties dict for nnUNet
265
+ # nnUNet expects spacing in (x, y, z) = (W, H, D) order
266
+ if self._original_spacing is not None:
267
+ # Input spacing is (D, H, W), nnUNet wants (W, H, D)
268
+ spacing_xyz = self._original_spacing[::-1] # Reverse order
269
+ else:
270
+ # Default to 1mm isotropic if not provided
271
+ spacing_xyz = (1.0, 1.0, 1.0)
272
+ logger.warning("No spacing provided, using 1mm isotropic")
273
+
274
+ properties = {
275
+ "spacing": list(spacing_xyz),
276
+ }
277
+
278
+ # Run prediction using nnUNet's predict_single_npy_array
279
+ # This handles all internal preprocessing, sliding window, and postprocessing
280
+ with torch.inference_mode():
281
+ segmentation = self.predictor.predict_single_npy_array(
282
+ input_image=volume_np,
283
+ image_properties=properties,
284
+ segmentation_previous_stage=None,
285
+ output_file_truncated=None,
286
+ save_or_return_probabilities=self.save_probabilities,
287
+ )
288
+
289
+ # Convert result to tensor
290
+ if isinstance(segmentation, tuple):
291
+ # If save_probabilities=True, returns (segmentation, probabilities)
292
+ segmentation = segmentation[0]
293
+
294
+ return torch.from_numpy(segmentation.astype(np.float32))
295
+
296
+ def postprocess(
297
+ self,
298
+ tensor: torch.Tensor,
299
+ original_shape: Tuple[int, int, int],
300
+ ) -> np.ndarray:
301
+ """Postprocess nnUNet output.
302
+
303
+ nnUNet typically outputs integer labels, so minimal postprocessing
304
+ is needed. This method ensures the output is a binary mask and
305
+ optionally resamples to the original shape if nnUNet changed it.
306
+
307
+ Args:
308
+ tensor: Model output tensor
309
+ original_shape: Original volume shape (D, H, W)
310
+
311
+ Returns:
312
+ Binary segmentation mask as numpy array
313
+ """
314
+ mask = tensor.cpu().numpy()
315
+
316
+ # nnUNet outputs integer labels (0 = background, 1+ = foreground classes)
317
+ # For binary segmentation, threshold at 0.5 if probabilities,
318
+ # or take any non-zero value as foreground
319
+ if mask.dtype == np.float32 or mask.dtype == np.float64:
320
+ # Probability output
321
+ mask = (mask > self.threshold).astype(np.uint8)
322
+ else:
323
+ # Integer label output - binarize (any label > 0 is foreground)
324
+ mask = (mask > 0).astype(np.uint8)
325
+
326
+ # Ensure output matches original shape
327
+ if mask.shape != original_shape:
328
+ # nnUNet may change shape due to resampling
329
+ # Use scipy to resample back
330
+ from scipy.ndimage import zoom
331
+
332
+ zoom_factors = tuple(
333
+ o / m for o, m in zip(original_shape, mask.shape)
334
+ )
335
+ mask = zoom(mask.astype(np.float32), zoom_factors, order=0) > 0.5
336
+ mask = mask.astype(np.uint8)
337
+
338
+ return mask
339
+
340
+
341
+ def register_nnunet() -> None:
342
+ """Register nnUNet model in the model registry.
343
+
344
+ This function is called by the model registry during lazy initialization.
345
+ It registers the nnUNet wrapper with the configured checkpoint path.
346
+ """
347
+ from seg_app.inference.model_registry import register_model
348
+ from seg_app.config.settings import NNUNET_CONFIG
349
+
350
+ # Only register if nnUNet checkpoint path is configured
351
+ if NNUNET_CONFIG.checkpoint_path is None:
352
+ logger.warning(
353
+ "nnUNet checkpoint path not configured. "
354
+ "Set NNUNET_CONFIG.checkpoint_path in settings.py to enable nnUNet."
355
+ )
356
+ return
357
+
358
+ config = ModelConfig(
359
+ model_id="nnunet-brain-lesion",
360
+ local_path=NNUNET_CONFIG.checkpoint_path,
361
+ device="cuda" if torch.cuda.is_available() else "cpu",
362
+ preprocessing={
363
+ "use_folds": NNUNET_CONFIG.use_folds,
364
+ "use_gaussian": True,
365
+ "use_mirroring": NNUNET_CONFIG.use_mirroring,
366
+ },
367
+ postprocessing={
368
+ "threshold": 0.5,
369
+ },
370
+ )
371
+
372
+ register_model("nnunet-brain-lesion", nnUNetWrapper, config)
373
+ logger.info(f"Registered nnUNet model: nnunet-brain-lesion")
374
+
375
+
376
+ # Convenience function for direct testing
377
+ def create_nnunet_model(checkpoint_path: str, device: str = "cuda") -> nnUNetWrapper:
378
+ """Create an nnUNet model instance for testing.
379
+
380
+ Args:
381
+ checkpoint_path: Path to nnUNet training output folder
382
+ device: Device to load model on
383
+
384
+ Returns:
385
+ Loaded nnUNet model ready for inference
386
+
387
+ Example:
388
+ >>> model = create_nnunet_model("/path/to/nnUNet_results/Dataset001/nnUNetTrainer__nnUNetPlans__3d_fullres")
389
+ >>> model.load()
390
+ >>> mask = model(volume, spacing=(1.0, 1.0, 1.0))
391
+ """
392
+ config = ModelConfig(
393
+ model_id="nnunet-test",
394
+ local_path=checkpoint_path,
395
+ device=device,
396
+ )
397
+ model = nnUNetWrapper(config)
398
+ model.load()
399
+ return model
seg_app/ui_slicer/api-client.js CHANGED
@@ -192,6 +192,64 @@ class ApiClient {
192
  }
193
  return await response.json();
194
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  }
196
 
197
  // Singleton instance
 
192
  }
193
  return await response.json();
194
  }
195
+
196
+ /**
197
+ * Upload ground truth mask for Dice score computation
198
+ * @param {string} volumeId - The volume ID
199
+ * @param {File} file - The ground truth NIfTI file
200
+ * @returns {Promise<{volume_id: string, gt_shape: number[], message: string}>}
201
+ */
202
+ async uploadGroundTruth(volumeId, file) {
203
+ const formData = new FormData();
204
+ formData.append('file', file);
205
+
206
+ const response = await fetch(`${this.baseUrl}/ground-truth/${volumeId}`, {
207
+ method: 'POST',
208
+ body: formData
209
+ });
210
+
211
+ if (!response.ok) {
212
+ const error = await response.json();
213
+ throw new Error(error.detail || `Ground truth upload failed: ${response.status}`);
214
+ }
215
+
216
+ return await response.json();
217
+ }
218
+
219
+ /**
220
+ * Get segmentation metrics including Dice score
221
+ * @param {string} volumeId - The volume ID
222
+ * @returns {Promise<{volume_id: string, has_ground_truth: boolean, voxel_count: number, volume_mm3: number, dice_score?: number, iou_score?: number, sensitivity?: number, precision?: number}>}
223
+ */
224
+ async getMetrics(volumeId) {
225
+ const response = await fetch(`${this.baseUrl}/metrics/${volumeId}`);
226
+
227
+ if (!response.ok) {
228
+ const error = await response.json();
229
+ throw new Error(error.detail || `Failed to get metrics: ${response.status}`);
230
+ }
231
+
232
+ return await response.json();
233
+ }
234
+
235
+ /**
236
+ * Get ground truth mask data as Uint8Array (raw bytes)
237
+ * @param {string} volumeId - The volume ID
238
+ * @returns {Promise<ArrayBuffer|null>}
239
+ */
240
+ async getGroundTruthData(volumeId) {
241
+ const response = await fetch(`${this.baseUrl}/ground-truth/${volumeId}/data`);
242
+
243
+ if (!response.ok) {
244
+ if (response.status === 404) {
245
+ return null; // No GT available
246
+ }
247
+ const error = await response.json();
248
+ throw new Error(error.detail || `Failed to get ground truth: ${response.status}`);
249
+ }
250
+
251
+ return await response.arrayBuffer();
252
+ }
253
  }
254
 
255
  // Singleton instance
seg_app/ui_slicer/app.js CHANGED
@@ -64,7 +64,11 @@ async function loadModels() {
64
  const select = document.getElementById('model-select');
65
  select.innerHTML = '';
66
 
67
- models.forEach((model, index) => {
 
 
 
 
68
  const option = document.createElement('option');
69
  // Handle both 'id' and 'model_id' from API
70
  const modelId = model.id || model.model_id;
@@ -72,12 +76,20 @@ async function loadModels() {
72
  option.textContent = model.display_name;
73
  select.appendChild(option);
74
 
75
- // Select first model by default
76
- if (index === 0) {
77
  appState.set('selectedModel', modelId);
 
 
78
  }
79
  });
80
 
 
 
 
 
 
 
81
  select.disabled = false;
82
  updateModelInfo();
83
  } catch (error) {
@@ -104,6 +116,9 @@ async function handleVolumeUpload(file) {
104
  showLoading('Uploading volume...');
105
 
106
  try {
 
 
 
107
  // Upload to backend
108
  const response = await apiClient.uploadVolume(file);
109
 
@@ -194,11 +209,17 @@ async function runSegmentation() {
194
  const shape = appState.get('volumeShape');
195
  const spacing = appState.get('volumeSpacing');
196
  mprViewer.loadMask(maskData, shape, spacing);
 
 
 
197
  }
198
 
199
  hideLoading();
200
  showToast('Success', `Segmentation complete (${response.model_id})`, 'success');
201
 
 
 
 
202
  } catch (error) {
203
  hideLoading();
204
  console.error('Segmentation failed:', error);
@@ -236,11 +257,17 @@ async function runRefinement() {
236
  const shape = appState.get('volumeShape');
237
  const spacing = appState.get('volumeSpacing');
238
  mprViewer.loadMask(maskData, shape, spacing);
 
 
 
239
  }
240
 
241
  hideLoading();
242
  showToast('Success', `Refinement complete (${response.model_id})`, 'success');
243
 
 
 
 
244
  } catch (error) {
245
  hideLoading();
246
  console.error('Refinement failed:', error);
@@ -248,6 +275,324 @@ async function runRefinement() {
248
  }
249
  }
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  // ============ Tool Management ============
252
 
253
  function enableTools(enabled) {
@@ -256,6 +601,10 @@ function enableTools(enabled) {
256
  document.getElementById('tool-clear').disabled = !enabled;
257
  document.getElementById('run-segment').disabled = !enabled;
258
  document.getElementById('run-refine').disabled = !enabled;
 
 
 
 
259
  }
260
 
261
  function setActiveTool(tool) {
@@ -420,6 +769,25 @@ function setupEventListeners() {
420
  document.getElementById('run-segment').addEventListener('click', runSegmentation);
421
  document.getElementById('run-refine').addEventListener('click', runRefinement);
422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  // Display sliders
424
  document.getElementById('window-slider').addEventListener('input', updateWindowValue);
425
  document.getElementById('level-slider').addEventListener('input', updateLevelValue);
@@ -447,6 +815,10 @@ async function initialize() {
447
  mprViewer.initialize();
448
  console.log('MPR viewer initialized');
449
 
 
 
 
 
450
  // Check backend connection
451
  const connected = await checkConnection();
452
  console.log('Backend connection:', connected);
 
64
  const select = document.getElementById('model-select');
65
  select.innerHTML = '';
66
 
67
+ // Find the 3D U-Net baseline model to set as default
68
+ const defaultModelId = 'unet3d-brain-tumor';
69
+ let defaultFound = false;
70
+
71
+ models.forEach((model) => {
72
  const option = document.createElement('option');
73
  // Handle both 'id' and 'model_id' from API
74
  const modelId = model.id || model.model_id;
 
76
  option.textContent = model.display_name;
77
  select.appendChild(option);
78
 
79
+ // Select 3D U-Net (Baseline) as default
80
+ if (modelId === defaultModelId) {
81
  appState.set('selectedModel', modelId);
82
+ select.value = modelId;
83
+ defaultFound = true;
84
  }
85
  });
86
 
87
+ // Fallback to first model if default not found
88
+ if (!defaultFound && models.length > 0) {
89
+ const firstModelId = models[0].id || models[0].model_id;
90
+ appState.set('selectedModel', firstModelId);
91
+ }
92
+
93
  select.disabled = false;
94
  updateModelInfo();
95
  } catch (error) {
 
116
  showLoading('Uploading volume...');
117
 
118
  try {
119
+ // Reset previous state (GT, metrics, overlap viewer)
120
+ resetMaskAndMetrics();
121
+
122
  // Upload to backend
123
  const response = await apiClient.uploadVolume(file);
124
 
 
209
  const shape = appState.get('volumeShape');
210
  const spacing = appState.get('volumeSpacing');
211
  mprViewer.loadMask(maskData, shape, spacing);
212
+
213
+ // Update overlap visualization
214
+ renderOverlapCanvas();
215
  }
216
 
217
  hideLoading();
218
  showToast('Success', `Segmentation complete (${response.model_id})`, 'success');
219
 
220
+ // Fetch and display metrics
221
+ await updateMetricsDisplay();
222
+
223
  } catch (error) {
224
  hideLoading();
225
  console.error('Segmentation failed:', error);
 
257
  const shape = appState.get('volumeShape');
258
  const spacing = appState.get('volumeSpacing');
259
  mprViewer.loadMask(maskData, shape, spacing);
260
+
261
+ // Update overlap visualization
262
+ renderOverlapCanvas();
263
  }
264
 
265
  hideLoading();
266
  showToast('Success', `Refinement complete (${response.model_id})`, 'success');
267
 
268
+ // Fetch and display metrics
269
+ await updateMetricsDisplay();
270
+
271
  } catch (error) {
272
  hideLoading();
273
  console.error('Refinement failed:', error);
 
275
  }
276
  }
277
 
278
+ // ============ Ground Truth & Metrics ============
279
+
280
+ async function handleGroundTruthUpload(file) {
281
+ const volumeId = appState.get('volumeId');
282
+ if (!volumeId) {
283
+ showToast('Warning', 'Please upload a volume first', 'info');
284
+ return;
285
+ }
286
+
287
+ showLoading('Uploading ground truth...');
288
+
289
+ try {
290
+ const response = await apiClient.uploadGroundTruth(volumeId, file);
291
+
292
+ // Update GT status
293
+ document.getElementById('gt-status').textContent = 'Loaded ✓';
294
+ document.getElementById('gt-status').style.color = 'var(--success-color)';
295
+
296
+ // Fetch the GT mask data for overlap visualization
297
+ const gtArrayBuffer = await apiClient.getGroundTruthData(volumeId);
298
+ const gtMaskData = new Uint8Array(gtArrayBuffer);
299
+ appState.setGroundTruth(gtMaskData);
300
+
301
+ // Render overlap canvas
302
+ renderOverlapCanvas();
303
+
304
+ hideLoading();
305
+ showToast('Success', `Ground truth loaded: ${response.gt_shape.join('×')}`, 'success');
306
+
307
+ // If we have a mask, update metrics to show Dice
308
+ if (appState.get('maskData')) {
309
+ await updateMetricsDisplay();
310
+ }
311
+
312
+ } catch (error) {
313
+ hideLoading();
314
+ console.error('Ground truth upload failed:', error);
315
+ showToast('Error', error.message, 'error');
316
+ }
317
+ }
318
+
319
+ async function updateMetricsDisplay() {
320
+ const volumeId = appState.get('volumeId');
321
+ if (!volumeId) return;
322
+
323
+ try {
324
+ const metrics = await apiClient.getMetrics(volumeId);
325
+
326
+ // Update volume display
327
+ const volumeEl = document.getElementById('mask-volume');
328
+ if (volumeEl) {
329
+ volumeEl.textContent = `${metrics.volume_ml.toFixed(2)} mL`;
330
+ }
331
+
332
+ // Update Dice metrics if ground truth is available
333
+ if (metrics.has_ground_truth) {
334
+ document.getElementById('gt-status').textContent = 'Loaded ✓';
335
+ document.getElementById('gt-status').style.color = 'var(--success-color)';
336
+
337
+ // Dice score with color coding
338
+ const diceEl = document.getElementById('dice-score');
339
+ if (diceEl && metrics.dice_score !== null) {
340
+ const dice = metrics.dice_score;
341
+ diceEl.textContent = dice.toFixed(4);
342
+ // Color code: green > 0.7, yellow 0.5-0.7, red < 0.5
343
+ if (dice >= 0.7) {
344
+ diceEl.style.color = 'var(--success-color)';
345
+ } else if (dice >= 0.5) {
346
+ diceEl.style.color = 'var(--warning-color)';
347
+ } else {
348
+ diceEl.style.color = 'var(--error-color)';
349
+ }
350
+ }
351
+
352
+ const iouEl = document.getElementById('iou-score');
353
+ if (iouEl && metrics.iou_score !== null) {
354
+ iouEl.textContent = metrics.iou_score.toFixed(4);
355
+ }
356
+
357
+ const sensEl = document.getElementById('sensitivity');
358
+ if (sensEl && metrics.sensitivity !== null) {
359
+ sensEl.textContent = metrics.sensitivity.toFixed(4);
360
+ }
361
+
362
+ const precEl = document.getElementById('precision-score');
363
+ if (precEl && metrics.precision !== null) {
364
+ precEl.textContent = metrics.precision.toFixed(4);
365
+ }
366
+ } else {
367
+ // No ground truth
368
+ document.getElementById('dice-score').textContent = '-';
369
+ document.getElementById('iou-score').textContent = '-';
370
+ document.getElementById('sensitivity').textContent = '-';
371
+ document.getElementById('precision-score').textContent = '-';
372
+ }
373
+
374
+ } catch (error) {
375
+ console.error('Failed to fetch metrics:', error);
376
+ }
377
+ }
378
+
379
+ /**
380
+ * Reset mask, ground truth, metrics display, and overlap viewer
381
+ * Called when uploading a new volume
382
+ */
383
+ function resetMaskAndMetrics() {
384
+ // Clear state
385
+ appState.clearMask();
386
+ appState.clearGroundTruth();
387
+
388
+ // Reset GT status indicator
389
+ const gtStatus = document.getElementById('gt-status');
390
+ if (gtStatus) {
391
+ gtStatus.textContent = 'Not loaded';
392
+ gtStatus.style.color = '';
393
+ }
394
+
395
+ // Reset metrics display
396
+ const volumeEl = document.getElementById('mask-volume');
397
+ if (volumeEl) volumeEl.textContent = '-';
398
+
399
+ const diceEl = document.getElementById('dice-score');
400
+ if (diceEl) {
401
+ diceEl.textContent = '-';
402
+ diceEl.style.color = '';
403
+ }
404
+
405
+ const iouEl = document.getElementById('iou-score');
406
+ if (iouEl) iouEl.textContent = '-';
407
+
408
+ const sensEl = document.getElementById('sensitivity');
409
+ if (sensEl) sensEl.textContent = '-';
410
+
411
+ const precEl = document.getElementById('precision-score');
412
+ if (precEl) precEl.textContent = '-';
413
+
414
+ // Reset overlap viewer
415
+ const placeholder = document.getElementById('overlap-placeholder');
416
+ if (placeholder) placeholder.style.display = '';
417
+
418
+ const sliceIndicator = document.getElementById('overlap-slice-indicator');
419
+ if (sliceIndicator) sliceIndicator.style.display = 'none';
420
+
421
+ const sliceLabel = document.getElementById('overlap-slice-num');
422
+ if (sliceLabel) sliceLabel.textContent = '-';
423
+
424
+ // Clear overlap canvas
425
+ const canvas = document.getElementById('overlap-canvas');
426
+ if (canvas) {
427
+ const ctx = canvas.getContext('2d');
428
+ ctx.fillStyle = '#1a1a2e';
429
+ ctx.fillRect(0, 0, canvas.width, canvas.height);
430
+ }
431
+
432
+ // Clear mask from viewer
433
+ mprViewer.clearMask();
434
+ }
435
+
436
+ // ============ Overlap Visualization ============
437
+
438
+ /**
439
+ * Render overlap canvas showing TP/FP/FN comparison between prediction and ground truth
440
+ * Colors: TP=green, FP=red, FN=blue, Pred-only (no GT)=orange
441
+ */
442
+ function renderOverlapCanvas() {
443
+ const canvas = document.getElementById('overlap-canvas');
444
+ const container = document.getElementById('overlap-canvas-container');
445
+ if (!canvas || !container) return;
446
+
447
+ const ctx = canvas.getContext('2d');
448
+ const maskData = appState.get('maskData');
449
+ const gtMaskData = appState.get('gtMaskData');
450
+ const volumeShape = appState.get('volumeShape');
451
+
452
+ // Get container size and set canvas to match
453
+ const rect = container.getBoundingClientRect();
454
+ const width = Math.floor(rect.width);
455
+ const height = Math.floor(rect.height);
456
+
457
+ if (width === 0 || height === 0) return;
458
+
459
+ canvas.width = width;
460
+ canvas.height = height;
461
+
462
+ // Clear canvas
463
+ ctx.fillStyle = '#1a1a2e';
464
+ ctx.fillRect(0, 0, width, height);
465
+
466
+ // Check if we have any data to display
467
+ if (!volumeShape || !maskData) {
468
+ ctx.fillStyle = '#666';
469
+ ctx.font = '12px sans-serif';
470
+ ctx.textAlign = 'center';
471
+ ctx.fillText('No segmentation loaded', width / 2, height / 2);
472
+ return;
473
+ }
474
+
475
+ const [D, H, W] = volumeShape; // D=depth, H=height, W=width
476
+ const sliceIndex = appState.get('overlapSliceIndex');
477
+
478
+ // Create image data at native resolution
479
+ const imageData = ctx.createImageData(D, H);
480
+ const pixels = imageData.data;
481
+
482
+ // Color definitions (RGBA)
483
+ const colors = {
484
+ TP: [76, 175, 80, 220], // Green - True Positive
485
+ FP: [244, 67, 54, 220], // Red - False Positive
486
+ FN: [33, 150, 243, 220], // Blue - False Negative
487
+ PRED_ONLY: [255, 152, 0, 220], // Orange - Prediction only (no GT available)
488
+ BG: [26, 26, 46, 255] // Dark background
489
+ };
490
+
491
+ // Iterate through slice pixels
492
+ for (let d = 0; d < D; d++) {
493
+ for (let h = 0; h < H; h++) {
494
+ // Volume index: (d, h, w) where w is the slice index
495
+ const volIdx = d * H * W + h * W + sliceIndex;
496
+
497
+ const pred = maskData[volIdx] > 0;
498
+ const gt = gtMaskData ? gtMaskData[volIdx] > 0 : null;
499
+
500
+ // Canvas pixel index: transposed for correct orientation
501
+ const canvasIdx = (h * D + d) * 4;
502
+
503
+ let color;
504
+ if (gt === null) {
505
+ // No ground truth available
506
+ color = pred ? colors.PRED_ONLY : colors.BG;
507
+ } else {
508
+ // Ground truth available - show TP/FP/FN
509
+ if (pred && gt) {
510
+ color = colors.TP; // True Positive
511
+ } else if (pred && !gt) {
512
+ color = colors.FP; // False Positive
513
+ } else if (!pred && gt) {
514
+ color = colors.FN; // False Negative
515
+ } else {
516
+ color = colors.BG; // True Negative (background)
517
+ }
518
+ }
519
+
520
+ pixels[canvasIdx] = color[0]; // R
521
+ pixels[canvasIdx + 1] = color[1]; // G
522
+ pixels[canvasIdx + 2] = color[2]; // B
523
+ pixels[canvasIdx + 3] = color[3]; // A
524
+ }
525
+ }
526
+
527
+ // Create temp canvas for scaling
528
+ const tempCanvas = document.createElement('canvas');
529
+ tempCanvas.width = D;
530
+ tempCanvas.height = H;
531
+ tempCanvas.getContext('2d').putImageData(imageData, 0, 0);
532
+
533
+ // Scale to fit canvas while maintaining aspect ratio
534
+ const aspect = D / H;
535
+ const canvasAspect = width / height;
536
+
537
+ let displayW, displayH, offsetX, offsetY;
538
+ if (canvasAspect > aspect) {
539
+ displayH = height;
540
+ displayW = displayH * aspect;
541
+ offsetX = (width - displayW) / 2;
542
+ offsetY = 0;
543
+ } else {
544
+ displayW = width;
545
+ displayH = displayW / aspect;
546
+ offsetX = 0;
547
+ offsetY = (height - displayH) / 2;
548
+ }
549
+
550
+ ctx.imageSmoothingEnabled = false; // Nearest neighbor for mask
551
+ ctx.drawImage(tempCanvas, offsetX, offsetY, displayW, displayH);
552
+
553
+ // Update slice indicator
554
+ const sliceIndicator = document.getElementById('overlap-slice-indicator');
555
+ const sliceLabel = document.getElementById('overlap-slice-num');
556
+ if (sliceLabel) {
557
+ sliceLabel.textContent = `${sliceIndex + 1} / ${W}`;
558
+ }
559
+ if (sliceIndicator) {
560
+ sliceIndicator.style.display = 'block';
561
+ }
562
+
563
+ // Hide placeholder
564
+ const placeholder = document.getElementById('overlap-placeholder');
565
+ if (placeholder) placeholder.style.display = 'none';
566
+ }
567
+
568
+ /**
569
+ * Initialize overlap viewer with scroll support
570
+ */
571
+ function initOverlapViewer() {
572
+ const container = document.getElementById('overlap-canvas-container');
573
+ if (!container) return;
574
+
575
+ // Mouse wheel scrolling
576
+ container.addEventListener('wheel', (e) => {
577
+ e.preventDefault();
578
+ const maskData = appState.get('maskData');
579
+ if (!maskData) return;
580
+
581
+ const delta = e.deltaY > 0 ? 1 : -1;
582
+ const currentSlice = appState.get('overlapSliceIndex');
583
+ appState.setOverlapSlice(currentSlice + delta);
584
+ renderOverlapCanvas();
585
+ }, { passive: false });
586
+
587
+ // Resize observer for responsive canvas
588
+ const resizeObserver = new ResizeObserver(() => {
589
+ if (appState.get('maskData')) {
590
+ renderOverlapCanvas();
591
+ }
592
+ });
593
+ resizeObserver.observe(container);
594
+ }
595
+
596
  // ============ Tool Management ============
597
 
598
  function enableTools(enabled) {
 
601
  document.getElementById('tool-clear').disabled = !enabled;
602
  document.getElementById('run-segment').disabled = !enabled;
603
  document.getElementById('run-refine').disabled = !enabled;
604
+
605
+ // Enable ground truth upload when volume is loaded
606
+ const gtBtn = document.getElementById('upload-gt-btn');
607
+ if (gtBtn) gtBtn.disabled = !enabled;
608
  }
609
 
610
  function setActiveTool(tool) {
 
769
  document.getElementById('run-segment').addEventListener('click', runSegmentation);
770
  document.getElementById('run-refine').addEventListener('click', runRefinement);
771
 
772
+ // Ground truth upload
773
+ const gtUploadBtn = document.getElementById('upload-gt-btn');
774
+ const gtFileInput = document.getElementById('gt-upload-input');
775
+
776
+ if (gtUploadBtn && gtFileInput) {
777
+ gtUploadBtn.addEventListener('click', () => {
778
+ gtFileInput.click();
779
+ });
780
+
781
+ gtFileInput.addEventListener('change', (e) => {
782
+ const file = e.target.files[0];
783
+ if (file) {
784
+ handleGroundTruthUpload(file);
785
+ }
786
+ // Reset input so same file can be selected again
787
+ e.target.value = '';
788
+ });
789
+ }
790
+
791
  // Display sliders
792
  document.getElementById('window-slider').addEventListener('input', updateWindowValue);
793
  document.getElementById('level-slider').addEventListener('input', updateLevelValue);
 
815
  mprViewer.initialize();
816
  console.log('MPR viewer initialized');
817
 
818
+ // Initialize overlap viewer (scroll support)
819
+ initOverlapViewer();
820
+ console.log('Overlap viewer initialized');
821
+
822
  // Check backend connection
823
  const connected = await checkConnection();
824
  console.log('Backend connection:', connected);
seg_app/ui_slicer/index.html CHANGED
@@ -137,27 +137,77 @@
137
  <div class="viewer-header">
138
  <span class="view-label info">INFO</span>
139
  </div>
140
- <div class="info-panel" id="info-view">
141
- <div class="crosshair-info">
142
- <h4>📍 Crosshair Position</h4>
143
- <p>X: <span id="pos-x">-</span></p>
144
- <p>Y: <span id="pos-y">-</span></p>
145
- <p>Z: <span id="pos-z">-</span></p>
146
- <p>Value: <span id="pos-value">-</span></p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  </div>
148
- <div class="seg-info">
149
- <h4>📊 Segmentation</h4>
150
- <p>Mask loaded: <span id="mask-status">No</span></p>
151
- <p>Voxels: <span id="mask-voxels">-</span></p>
152
- </div>
153
- <div class="keyboard-shortcuts">
154
- <h4>⌨️ Shortcuts</h4>
155
- <p><kbd>Scroll</kbd> Change slice</p>
156
- <p><kbd>Click</kbd> Add prompt</p>
157
- <p><kbd>P</kbd> Positive mode</p>
158
- <p><kbd>N</kbd> Negative mode</p>
159
- <p><kbd>C</kbd> Clear prompts</p>
160
- <p><kbd>R</kbd> Run segmentation</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  </div>
162
  </div>
163
  </div>
 
137
  <div class="viewer-header">
138
  <span class="view-label info">INFO</span>
139
  </div>
140
+ <div class="info-panel-split" id="info-view">
141
+ <!-- Left Column: Metrics & Info -->
142
+ <div class="info-left">
143
+ <div class="crosshair-info">
144
+ <h4>📍 Crosshair Position</h4>
145
+ <p>X: <span id="pos-x">-</span></p>
146
+ <p>Y: <span id="pos-y">-</span></p>
147
+ <p>Z: <span id="pos-z">-</span></p>
148
+ <p>Value: <span id="pos-value">-</span></p>
149
+ </div>
150
+ <div class="seg-info">
151
+ <h4>📊 Segmentation</h4>
152
+ <p>Mask loaded: <span id="mask-status">No</span></p>
153
+ <p>Voxels: <span id="mask-voxels">-</span></p>
154
+ <p>Volume: <span id="mask-volume">-</span></p>
155
+ </div>
156
+ <div class="dice-info">
157
+ <h4>🎯 Dice Metrics</h4>
158
+ <p>Ground Truth: <span id="gt-status">Not loaded</span></p>
159
+ <p>Dice Score: <span id="dice-score">-</span></p>
160
+ <p>IoU Score: <span id="iou-score">-</span></p>
161
+ <p>Sensitivity: <span id="sensitivity">-</span></p>
162
+ <p>Precision: <span id="precision-score">-</span></p>
163
+ <input type="file" id="gt-upload-input" accept=".nii,.nii.gz" style="display: none;">
164
+ <button id="upload-gt-btn" class="btn btn-secondary btn-small" disabled>
165
+ 📤 Upload Ground Truth
166
+ </button>
167
+ </div>
168
+ <div class="keyboard-shortcuts">
169
+ <h4>⌨️ Shortcuts</h4>
170
+ <p><kbd>Scroll</kbd> Change slice</p>
171
+ <p><kbd>Click</kbd> Add prompt</p>
172
+ <p><kbd>P</kbd> Positive mode</p>
173
+ <p><kbd>N</kbd> Negative mode</p>
174
+ <p><kbd>C</kbd> Clear prompts</p>
175
+ <p><kbd>R</kbd> Run segmentation</p>
176
+ </div>
177
  </div>
178
+ <!-- Right Column: Overlap Viewer -->
179
+ <div class="info-right">
180
+ <div class="overlap-viewer">
181
+ <h4>🔍 Mask Comparison</h4>
182
+ <div class="overlap-canvas-container" id="overlap-canvas-container">
183
+ <canvas id="overlap-canvas"></canvas>
184
+ <div class="slice-indicator" id="overlap-slice-indicator">
185
+ <span id="overlap-slice-num">-</span>
186
+ </div>
187
+ <p class="overlap-placeholder" id="overlap-placeholder">
188
+ Run segmentation to view mask<br>
189
+ <small>Scroll to change slice</small>
190
+ </p>
191
+ </div>
192
+ <div class="overlap-legend">
193
+ <div class="legend-item">
194
+ <span class="legend-color tp"></span>
195
+ <span>TP (Both agree)</span>
196
+ </div>
197
+ <div class="legend-item">
198
+ <span class="legend-color fp"></span>
199
+ <span>FP (Pred only)</span>
200
+ </div>
201
+ <div class="legend-item">
202
+ <span class="legend-color fn"></span>
203
+ <span>FN (GT only)</span>
204
+ </div>
205
+ <div class="legend-item">
206
+ <span class="legend-color pred-only"></span>
207
+ <span>Prediction</span>
208
+ </div>
209
+ </div>
210
+ </div>
211
  </div>
212
  </div>
213
  </div>
seg_app/ui_slicer/state.js CHANGED
@@ -65,6 +65,11 @@ class AppState {
65
  maskLoaded: false,
66
  maskData: null, // Uint8Array
67
 
 
 
 
 
 
68
  // Display settings
69
  windowWidth: 400,
70
  windowLevel: 40,
@@ -215,6 +220,9 @@ class AppState {
215
  volumeData: null,
216
  maskLoaded: false,
217
  maskData: null,
 
 
 
218
  prompts: []
219
  });
220
  }
@@ -330,6 +338,47 @@ class AppState {
330
  });
331
  }
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  // ============ Loading State ============
334
 
335
  /**
 
65
  maskLoaded: false,
66
  maskData: null, // Uint8Array
67
 
68
+ // Ground Truth (for Dice comparison)
69
+ gtLoaded: false,
70
+ gtMaskData: null, // Uint8Array
71
+ overlapSliceIndex: 0, // Slice index for overlap viewer
72
+
73
  // Display settings
74
  windowWidth: 400,
75
  windowLevel: 40,
 
220
  volumeData: null,
221
  maskLoaded: false,
222
  maskData: null,
223
+ gtLoaded: false,
224
+ gtMaskData: null,
225
+ overlapSliceIndex: 0,
226
  prompts: []
227
  });
228
  }
 
338
  });
339
  }
340
 
341
+ // ============ Ground Truth Methods ============
342
+
343
+ /**
344
+ * Set ground truth mask data
345
+ * @param {Uint8Array} data
346
+ */
347
+ setGroundTruth(data) {
348
+ // Set overlap slice to center initially
349
+ const shape = this.get('volumeShape');
350
+ const centerSlice = shape ? Math.floor(shape[2] / 2) : 0;
351
+
352
+ this.update({
353
+ gtLoaded: true,
354
+ gtMaskData: data,
355
+ overlapSliceIndex: centerSlice
356
+ });
357
+ }
358
+
359
+ /**
360
+ * Clear ground truth data
361
+ */
362
+ clearGroundTruth() {
363
+ this.update({
364
+ gtLoaded: false,
365
+ gtMaskData: null,
366
+ overlapSliceIndex: 0
367
+ });
368
+ }
369
+
370
+ /**
371
+ * Set overlap viewer slice index
372
+ * @param {number} index
373
+ */
374
+ setOverlapSlice(index) {
375
+ const shape = this.get('volumeShape');
376
+ if (!shape) return;
377
+ const maxIndex = shape[2] - 1; // W dimension (axial slices through Z)
378
+ const clampedIndex = Math.max(0, Math.min(index, maxIndex));
379
+ this.set('overlapSliceIndex', clampedIndex);
380
+ }
381
+
382
  // ============ Loading State ============
383
 
384
  /**
seg_app/ui_slicer/styles.css CHANGED
@@ -17,6 +17,11 @@
17
  --border-color: #2a2a4a;
18
  --shadow-color: rgba(0, 0, 0, 0.3);
19
 
 
 
 
 
 
20
  /* View colors */
21
  --axial-color: #4caf50;
22
  --sagittal-color: #2196f3;
@@ -484,46 +489,178 @@ body {
484
  cursor: crosshair;
485
  }
486
 
487
- /* Info Panel (4th quadrant) */
488
- .info-panel {
489
  flex: 1;
490
- padding: var(--spacing-md);
 
 
 
 
 
 
 
 
491
  overflow-y: auto;
 
 
492
  }
493
 
494
- .info-panel h4 {
495
- font-size: 0.8rem;
496
- margin-bottom: var(--spacing-sm);
497
- color: var(--text-primary);
 
 
498
  }
499
 
500
- .info-panel p {
 
501
  font-size: 0.75rem;
502
- color: var(--text-secondary);
503
  margin-bottom: var(--spacing-xs);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  }
505
 
506
  .crosshair-info,
507
  .seg-info,
 
508
  .keyboard-shortcuts {
509
- margin-bottom: var(--spacing-md);
510
- padding-bottom: var(--spacing-md);
511
  border-bottom: 1px solid var(--border-color);
512
  }
513
 
 
 
 
 
 
 
 
 
 
 
514
  .keyboard-shortcuts:last-child {
515
  border-bottom: none;
516
  }
517
 
 
 
 
 
518
  kbd {
519
  display: inline-block;
520
- padding: 2px 6px;
521
- font-size: 0.7rem;
522
  font-family: monospace;
523
  background: var(--bg-tertiary);
524
  border: 1px solid var(--border-color);
525
  border-radius: 3px;
526
- min-width: 50px;
527
  text-align: center;
528
  }
529
 
 
17
  --border-color: #2a2a4a;
18
  --shadow-color: rgba(0, 0, 0, 0.3);
19
 
20
+ /* Semantic colors for metrics */
21
+ --success-color: #4caf50;
22
+ --warning-color: #ff9800;
23
+ --error-color: #f44336;
24
+
25
  /* View colors */
26
  --axial-color: #4caf50;
27
  --sagittal-color: #2196f3;
 
489
  cursor: crosshair;
490
  }
491
 
492
+ /* Info Panel (4th quadrant) - Split Layout */
493
+ .info-panel-split {
494
  flex: 1;
495
+ display: flex;
496
+ gap: var(--spacing-sm);
497
+ padding: var(--spacing-sm);
498
+ overflow: hidden;
499
+ }
500
+
501
+ .info-left {
502
+ flex: 1;
503
+ min-width: 0;
504
  overflow-y: auto;
505
+ padding-right: var(--spacing-sm);
506
+ border-right: 1px solid var(--border-color);
507
  }
508
 
509
+ .info-right {
510
+ flex: 1;
511
+ min-width: 0;
512
+ display: flex;
513
+ flex-direction: column;
514
+ overflow-y: auto;
515
  }
516
 
517
+ .info-left h4,
518
+ .info-right h4 {
519
  font-size: 0.75rem;
 
520
  margin-bottom: var(--spacing-xs);
521
+ color: var(--text-primary);
522
+ }
523
+
524
+ .info-left p {
525
+ font-size: 0.7rem;
526
+ color: var(--text-secondary);
527
+ margin-bottom: 2px;
528
+ }
529
+
530
+ /* Overlap Viewer */
531
+ .overlap-viewer {
532
+ display: flex;
533
+ flex-direction: column;
534
+ height: 100%;
535
+ }
536
+
537
+ .overlap-canvas-container {
538
+ position: relative;
539
+ flex: 1;
540
+ min-height: 160px;
541
+ background: var(--bg-primary);
542
+ border: 1px solid var(--border-color);
543
+ border-radius: var(--radius-sm);
544
+ display: flex;
545
+ align-items: center;
546
+ justify-content: center;
547
+ overflow: hidden;
548
+ cursor: ns-resize;
549
+ }
550
+
551
+ #overlap-canvas {
552
+ width: 100%;
553
+ height: 100%;
554
+ display: block;
555
+ }
556
+
557
+ .overlap-canvas-container .slice-indicator {
558
+ position: absolute;
559
+ bottom: 4px;
560
+ left: 50%;
561
+ transform: translateX(-50%);
562
+ background: rgba(0, 0, 0, 0.7);
563
+ color: var(--text-primary);
564
+ padding: 2px 8px;
565
+ border-radius: var(--radius-sm);
566
+ font-size: 0.65rem;
567
+ font-family: var(--font-mono);
568
+ pointer-events: none;
569
+ z-index: 10;
570
+ }
571
+
572
+ .overlap-placeholder {
573
+ position: absolute;
574
+ font-size: 0.7rem;
575
+ color: var(--text-secondary);
576
+ text-align: center;
577
+ padding: var(--spacing-sm);
578
+ }
579
+
580
+ .overlap-placeholder small {
581
+ font-size: 0.6rem;
582
+ opacity: 0.7;
583
+ }
584
+
585
+ .overlap-placeholder.hidden {
586
+ display: none;
587
+ }
588
+
589
+ /* Overlap Legend */
590
+ .overlap-legend {
591
+ margin-top: var(--spacing-xs);
592
+ display: grid;
593
+ grid-template-columns: 1fr 1fr;
594
+ gap: 2px;
595
+ font-size: 0.6rem;
596
+ }
597
+
598
+ .legend-item {
599
+ display: flex;
600
+ align-items: center;
601
+ gap: 4px;
602
+ color: var(--text-secondary);
603
+ }
604
+
605
+ .legend-color {
606
+ width: 10px;
607
+ height: 10px;
608
+ border-radius: 2px;
609
+ flex-shrink: 0;
610
+ }
611
+
612
+ .legend-color.tp {
613
+ background: #4caf50; /* Green - True Positive */
614
+ }
615
+
616
+ .legend-color.fp {
617
+ background: #f44336; /* Red - False Positive */
618
+ }
619
+
620
+ .legend-color.fn {
621
+ background: #2196f3; /* Blue - False Negative */
622
+ }
623
+
624
+ .legend-color.pred-only {
625
+ background: #ff9800; /* Orange - Prediction only (no GT) */
626
  }
627
 
628
  .crosshair-info,
629
  .seg-info,
630
+ .dice-info,
631
  .keyboard-shortcuts {
632
+ margin-bottom: var(--spacing-sm);
633
+ padding-bottom: var(--spacing-sm);
634
  border-bottom: 1px solid var(--border-color);
635
  }
636
 
637
+ .dice-info #gt-status {
638
+ font-weight: 500;
639
+ }
640
+
641
+ .dice-info .btn-small {
642
+ margin-top: var(--spacing-xs);
643
+ padding: 3px 6px;
644
+ font-size: 0.6rem;
645
+ }
646
+
647
  .keyboard-shortcuts:last-child {
648
  border-bottom: none;
649
  }
650
 
651
+ .keyboard-shortcuts p {
652
+ font-size: 0.6rem;
653
+ }
654
+
655
  kbd {
656
  display: inline-block;
657
+ padding: 1px 4px;
658
+ font-size: 0.6rem;
659
  font-family: monospace;
660
  background: var(--bg-tertiary);
661
  border: 1px solid var(--border-color);
662
  border-radius: 3px;
663
+ min-width: 40px;
664
  text-align: center;
665
  }
666