Spaces:
Sleeping
Sleeping
Commit ·
d2c2086
1
Parent(s): cf575e8
feat: Add nnUNet integration, Dice metrics, and split INFO panel with overlap viewer
Browse files- Add nnUNet v2 wrapper (nnunet_wrapper.py) for brain lesion segmentation
- Register nnUNet and fix unet3d-brain-tumor model in registry
- Add Dice score, IoU, sensitivity, precision metrics computation
- Add GT upload endpoint and metrics API endpoints
- Split INFO panel: left (metrics), right (mask comparison viewer)
- Overlap viewer shows TP (green), FP (red), FN (blue) visualization
- Add scroll-based slice navigation for overlap viewer
- Set 3D U-Net (Baseline) as default model selection
- Reset mask/metrics/overlap when uploading new volume
- Update README with local run instructions
- README.md +173 -0
- app.py +24 -0
- requirements.txt +6 -0
- seg_app/backend/api.py +219 -10
- seg_app/config/settings.py +41 -0
- seg_app/inference/model_registry.py +8 -0
- seg_app/inference/orchestrator.py +14 -2
- seg_app/metrics/segmentation_metrics.py +138 -0
- seg_app/models/monai_autoseg.py +19 -0
- seg_app/models/nnunet_wrapper.py +399 -0
- seg_app/ui_slicer/api-client.js +58 -0
- seg_app/ui_slicer/app.js +375 -3
- seg_app/ui_slicer/index.html +70 -20
- seg_app/ui_slicer/state.js +49 -0
- seg_app/ui_slicer/styles.css +151 -14
README.md
CHANGED
|
@@ -363,3 +363,176 @@ Key endpoints:
|
|
| 363 |
- `POST /refine` – Refine segmentation with additional prompts
|
| 364 |
- `GET /mask/{volume_id}/data` – Download raw mask data
|
| 365 |
- `GET /health` – Health check endpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
- `POST /refine` – Refine segmentation with additional prompts
|
| 364 |
- `GET /mask/{volume_id}/data` – Download raw mask data
|
| 365 |
- `GET /health` – Health check endpoint
|
| 366 |
+
|
| 367 |
+
---
|
| 368 |
+
|
| 369 |
+
## 🧬 Integrating Your Own nnUNet Model
|
| 370 |
+
|
| 371 |
+
This section explains how to integrate a trained nnUNet v2 model into the web application as a baseline model option.
|
| 372 |
+
|
| 373 |
+
### Prerequisites
|
| 374 |
+
|
| 375 |
+
1. **Trained nnUNet model**: You need a completed nnUNet v2 training run with:
|
| 376 |
+
- Code location (for reference): e.g., `/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet/`
|
| 377 |
+
- **Checkpoints directory**: e.g., `/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/`
|
| 378 |
+
|
| 379 |
+
2. **nnUNet v2 installed**: The nnunetv2 package must be installed in your conda environment
|
| 380 |
+
|
| 381 |
+
### Step 1: Install nnUNet v2
|
| 382 |
+
|
| 383 |
+
First, install nnUNet v2 in your conda environment:
|
| 384 |
+
|
| 385 |
+
```powershell
|
| 386 |
+
# Activate your environment
|
| 387 |
+
conda activate seg_app
|
| 388 |
+
|
| 389 |
+
# Install nnUNet v2
|
| 390 |
+
pip install nnunetv2
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
Or uncomment the line in `requirements.txt`:
|
| 394 |
+
|
| 395 |
+
```pip
|
| 396 |
+
# nnunetv2>=2.2
|
| 397 |
+
```
|
| 398 |
+
|
| 399 |
+
And run:
|
| 400 |
+
```powershell
|
| 401 |
+
pip install -r requirements.txt
|
| 402 |
+
```
|
| 403 |
+
|
| 404 |
+
### Step 2: Locate Your nnUNet Checkpoint Path
|
| 405 |
+
|
| 406 |
+
nnUNet training creates a specific folder structure. You need the path to the **trainer output folder**:
|
| 407 |
+
|
| 408 |
+
```
|
| 409 |
+
nnUNet_results/
|
| 410 |
+
└── Dataset###_Name/ # e.g., Dataset001_BrainLesion
|
| 411 |
+
└── nnUNetTrainer__nnUNetPlans__3d_fullres/ # ← THIS IS THE PATH YOU NEED
|
| 412 |
+
├── plans.json # Training plans
|
| 413 |
+
├── dataset.json # Dataset configuration
|
| 414 |
+
├── fold_0/
|
| 415 |
+
│ ├── checkpoint_final.pth # Model weights
|
| 416 |
+
│ └── checkpoint_best.pth # Best validation weights
|
| 417 |
+
├── fold_1/
|
| 418 |
+
│ └── ...
|
| 419 |
+
└── ...
|
| 420 |
+
```
|
| 421 |
+
|
| 422 |
+
**Your checkpoint path** should point to the trainer folder, for example:
|
| 423 |
+
```
|
| 424 |
+
/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres
|
| 425 |
+
```
|
| 426 |
+
|
| 427 |
+
### Step 3: Configure nnUNet in Settings
|
| 428 |
+
|
| 429 |
+
Edit the file `seg_app/config/settings.py` and update the `NNUNET_CONFIG`:
|
| 430 |
+
|
| 431 |
+
```python
|
| 432 |
+
# Global nnUNet configuration instance
|
| 433 |
+
NNUNET_CONFIG = nnUNetConfig(
|
| 434 |
+
# Set the path to your trained nnUNet model folder
|
| 435 |
+
checkpoint_path="/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres",
|
| 436 |
+
|
| 437 |
+
# Optional: which folds to use for inference
|
| 438 |
+
# "all" = use all available folds (ensemble), or [0] for single fold
|
| 439 |
+
use_folds="all",
|
| 440 |
+
|
| 441 |
+
# Optional: test-time augmentation (mirroring)
|
| 442 |
+
# True = more accurate but slower, False = faster inference
|
| 443 |
+
use_mirroring=True,
|
| 444 |
+
|
| 445 |
+
# Optional: customize the display name in the UI dropdown
|
| 446 |
+
display_name="nnU-Net (Brain Lesion)",
|
| 447 |
+
)
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
### Step 4: Verify the Integration
|
| 451 |
+
|
| 452 |
+
After configuration, restart the application and verify nnUNet appears in the model dropdown:
|
| 453 |
+
|
| 454 |
+
```powershell
|
| 455 |
+
# For Gradio UI
|
| 456 |
+
python app.py
|
| 457 |
+
|
| 458 |
+
# For VTK.js Slicer (FastAPI backend)
|
| 459 |
+
uvicorn seg_app.backend.api:app --reload --port 8000
|
| 460 |
+
```
|
| 461 |
+
|
| 462 |
+
Open the web interface and check that **"nnU-Net (Brain Lesion)"** appears in the model selection dropdown.
|
| 463 |
+
|
| 464 |
+
### Step 5: Test Inference
|
| 465 |
+
|
| 466 |
+
1. Upload a NIfTI volume (`.nii` or `.nii.gz`)
|
| 467 |
+
2. Select **"nnU-Net (Brain Lesion)"** from the model dropdown
|
| 468 |
+
3. Click **"Run Segmentation"**
|
| 469 |
+
4. View the segmentation overlay and volume metrics
|
| 470 |
+
|
| 471 |
+
### nnUNet Configuration Options
|
| 472 |
+
|
| 473 |
+
| Option | Default | Description |
|
| 474 |
+
|--------|---------|-------------|
|
| 475 |
+
| `checkpoint_path` | `None` | **Required.** Path to nnUNet trainer output folder |
|
| 476 |
+
| `use_folds` | `"all"` | Which folds to use. `"all"` for ensemble, `[0]` for single fold |
|
| 477 |
+
| `use_mirroring` | `True` | Test-time augmentation. Improves accuracy but ~4x slower |
|
| 478 |
+
| `display_name` | `"nnU-Net (Brain Lesion)"` | Name shown in UI dropdown |
|
| 479 |
+
|
| 480 |
+
### Advanced: Multiple nnUNet Models
|
| 481 |
+
|
| 482 |
+
To add multiple nnUNet models for different tasks, you can modify the registration logic in `seg_app/models/nnunet_wrapper.py`. The `register_nnunet()` function can be extended to register multiple models:
|
| 483 |
+
|
| 484 |
+
```python
|
| 485 |
+
# Example: Register multiple nnUNet models
|
| 486 |
+
def register_nnunet() -> None:
|
| 487 |
+
from seg_app.inference.model_registry import register_model
|
| 488 |
+
|
| 489 |
+
# Model 1: Brain Lesion
|
| 490 |
+
brain_config = ModelConfig(
|
| 491 |
+
model_id="nnunet-brain-lesion",
|
| 492 |
+
local_path="/path/to/nnUNet_results/DatasetBrain/nnUNetTrainer__nnUNetPlans__3d_fullres",
|
| 493 |
+
device="cuda",
|
| 494 |
+
)
|
| 495 |
+
register_model("nnunet-brain-lesion", nnUNetWrapper, brain_config)
|
| 496 |
+
|
| 497 |
+
# Model 2: Liver (example)
|
| 498 |
+
liver_config = ModelConfig(
|
| 499 |
+
model_id="nnunet-liver",
|
| 500 |
+
local_path="/path/to/nnUNet_results/DatasetLiver/nnUNetTrainer__nnUNetPlans__3d_fullres",
|
| 501 |
+
device="cuda",
|
| 502 |
+
)
|
| 503 |
+
register_model("nnunet-liver", nnUNetWrapper, liver_config)
|
| 504 |
+
```
|
| 505 |
+
|
| 506 |
+
### Troubleshooting nnUNet Integration
|
| 507 |
+
|
| 508 |
+
**"nnunetv2 is not installed"**
|
| 509 |
+
```powershell
|
| 510 |
+
pip install nnunetv2
|
| 511 |
+
```
|
| 512 |
+
|
| 513 |
+
**"plans.json not found"**
|
| 514 |
+
- Ensure `checkpoint_path` points to the trainer folder (not the dataset folder)
|
| 515 |
+
- The path should contain `plans.json` or `nnUNetPlans.json`
|
| 516 |
+
|
| 517 |
+
**"checkpoint_final.pth not found"**
|
| 518 |
+
- Verify training completed successfully
|
| 519 |
+
- Check if only `checkpoint_best.pth` exists (modify the wrapper to use it)
|
| 520 |
+
|
| 521 |
+
**"Model not appearing in dropdown"**
|
| 522 |
+
- Check that `checkpoint_path` is not `None` in `NNUNET_CONFIG`
|
| 523 |
+
- Look for warnings in the terminal when starting the app
|
| 524 |
+
|
| 525 |
+
**Out of memory during inference**
|
| 526 |
+
- Reduce `use_mirroring` to `False` (reduces memory by ~4x)
|
| 527 |
+
- Use fewer folds: `use_folds=[0]` instead of `"all"`
|
| 528 |
+
- Reduce input volume size
|
| 529 |
+
|
| 530 |
+
### Files Modified for nnUNet Integration
|
| 531 |
+
|
| 532 |
+
| File | Purpose |
|
| 533 |
+
|------|---------|
|
| 534 |
+
| [seg_app/models/nnunet_wrapper.py](seg_app/models/nnunet_wrapper.py) | nnUNet model wrapper implementing BaseModel interface |
|
| 535 |
+
| [seg_app/config/settings.py](seg_app/config/settings.py) | nnUNet configuration (checkpoint path, inference settings) |
|
| 536 |
+
| [seg_app/inference/model_registry.py](seg_app/inference/model_registry.py) | Registers nnUNet model during lazy initialization |
|
| 537 |
+
| [seg_app/inference/orchestrator.py](seg_app/inference/orchestrator.py) | Adds nnUNet to available models list |
|
| 538 |
+
| [requirements.txt](requirements.txt) | Optional nnunetv2 dependency |
|
app.py
CHANGED
|
@@ -131,3 +131,27 @@ if __name__ == "__main__":
|
|
| 131 |
server_port=7869,
|
| 132 |
share=False,
|
| 133 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
server_port=7869,
|
| 132 |
share=False,
|
| 133 |
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
>> First Launch API server backend: Start the FastAPI backend with uvicorn
|
| 139 |
+
# uvicorn seg_app.api.app:app --reload --host 127.0.0.1 --port 8000
|
| 140 |
+
uvicorn seg_app.backend.api:app --reload --host 127.0.0.1 --port 8000
|
| 141 |
+
|
| 142 |
+
>> Second, Start the FastAPI Frontend:
|
| 143 |
+
cd seg_app/ui_slicer
|
| 144 |
+
python serve_frontend.py
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
>> Launch Gradio App for seg_app.
|
| 148 |
+
python app.py
|
| 149 |
+
|
| 150 |
+
# Terminal 1: FastAPI Backend
|
| 151 |
+
uvicorn seg_app.backend.api:app --reload --host 127.0.0.1 --port 8000
|
| 152 |
+
|
| 153 |
+
# Terminal 2: Frontend Server
|
| 154 |
+
cd seg_app/ui_slicer
|
| 155 |
+
python serve_frontend.py
|
| 156 |
+
|
| 157 |
+
"""
|
requirements.txt
CHANGED
|
@@ -42,3 +42,9 @@ huggingface_hub>=0.20.0,<0.28.0
|
|
| 42 |
# =============================================================================
|
| 43 |
einops>=0.6.0
|
| 44 |
timm>=0.9.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# =============================================================================
|
| 43 |
einops>=0.6.0
|
| 44 |
timm>=0.9.0
|
| 45 |
+
|
| 46 |
+
# =============================================================================
|
| 47 |
+
# nnUNet v2 (optional - for local nnUNet models)
|
| 48 |
+
# =============================================================================
|
| 49 |
+
# Uncomment to enable nnUNet support:
|
| 50 |
+
# nnunetv2>=2.2
|
seg_app/backend/api.py
CHANGED
|
@@ -108,6 +108,30 @@ class AvailableModelsResponse(BaseModel):
|
|
| 108 |
models: List[Dict[str, str]]
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# =============================================================================
|
| 112 |
# In-Memory State Storage
|
| 113 |
# =============================================================================
|
|
@@ -120,6 +144,7 @@ class VolumeState:
|
|
| 120 |
spacing: Tuple[float, float, float]
|
| 121 |
affine: Optional[np.ndarray] = None
|
| 122 |
mask: Optional[np.ndarray] = None
|
|
|
|
| 123 |
last_model_id: Optional[str] = None
|
| 124 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 125 |
|
|
@@ -175,6 +200,17 @@ class StateManager:
|
|
| 175 |
return True
|
| 176 |
return False
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
def list_volumes(self) -> List[str]:
|
| 179 |
"""List all stored volume IDs."""
|
| 180 |
return list(self._volumes.keys())
|
|
@@ -692,6 +728,181 @@ def create_api_app() -> FastAPI:
|
|
| 692 |
}
|
| 693 |
)
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
return app
|
| 696 |
|
| 697 |
|
|
@@ -704,13 +915,11 @@ if __name__ == "__main__":
|
|
| 704 |
uvicorn.run(app, host="127.0.0.1", port=8000)
|
| 705 |
|
| 706 |
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
python serve_frontend.py
|
| 716 |
-
"""
|
|
|
|
| 108 |
models: List[Dict[str, str]]
|
| 109 |
|
| 110 |
|
| 111 |
+
class MetricsResponse(BaseModel):
|
| 112 |
+
"""Segmentation metrics including Dice score."""
|
| 113 |
+
volume_id: str
|
| 114 |
+
has_ground_truth: bool
|
| 115 |
+
voxel_count: int
|
| 116 |
+
volume_mm3: float
|
| 117 |
+
volume_ml: float
|
| 118 |
+
# Dice metrics (only if ground truth is available)
|
| 119 |
+
dice_score: Optional[float] = None
|
| 120 |
+
iou_score: Optional[float] = None
|
| 121 |
+
sensitivity: Optional[float] = None
|
| 122 |
+
precision: Optional[float] = None
|
| 123 |
+
true_positives: Optional[int] = None
|
| 124 |
+
false_positives: Optional[int] = None
|
| 125 |
+
false_negatives: Optional[int] = None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class GroundTruthUploadResponse(BaseModel):
|
| 129 |
+
"""Response after uploading ground truth mask."""
|
| 130 |
+
volume_id: str
|
| 131 |
+
gt_shape: Tuple[int, int, int]
|
| 132 |
+
message: str = "Ground truth uploaded successfully"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
# =============================================================================
|
| 136 |
# In-Memory State Storage
|
| 137 |
# =============================================================================
|
|
|
|
| 144 |
spacing: Tuple[float, float, float]
|
| 145 |
affine: Optional[np.ndarray] = None
|
| 146 |
mask: Optional[np.ndarray] = None
|
| 147 |
+
ground_truth: Optional[np.ndarray] = None # For Dice score computation
|
| 148 |
last_model_id: Optional[str] = None
|
| 149 |
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 150 |
|
|
|
|
| 200 |
return True
|
| 201 |
return False
|
| 202 |
|
| 203 |
+
def update_ground_truth(
|
| 204 |
+
self,
|
| 205 |
+
volume_id: str,
|
| 206 |
+
ground_truth: np.ndarray,
|
| 207 |
+
) -> None:
|
| 208 |
+
"""Update the ground truth mask for a volume."""
|
| 209 |
+
if volume_id not in self._volumes:
|
| 210 |
+
raise KeyError(f"Volume not found: {volume_id}")
|
| 211 |
+
self._volumes[volume_id].ground_truth = ground_truth
|
| 212 |
+
logger.info(f"Updated ground truth for volume {volume_id}")
|
| 213 |
+
|
| 214 |
def list_volumes(self) -> List[str]:
|
| 215 |
"""List all stored volume IDs."""
|
| 216 |
return list(self._volumes.keys())
|
|
|
|
| 728 |
}
|
| 729 |
)
|
| 730 |
|
| 731 |
+
# -------------------------------------------------------------------------
|
| 732 |
+
# Ground Truth Upload (for Dice score computation)
|
| 733 |
+
# -------------------------------------------------------------------------
|
| 734 |
+
|
| 735 |
+
@app.post("/ground-truth/{volume_id}", response_model=GroundTruthUploadResponse)
|
| 736 |
+
async def upload_ground_truth(volume_id: str, file: UploadFile = File(...)):
|
| 737 |
+
"""Upload a ground truth segmentation mask for Dice score computation.
|
| 738 |
+
|
| 739 |
+
The ground truth mask must match the shape of the uploaded volume.
|
| 740 |
+
Accepts .nii or .nii.gz files.
|
| 741 |
+
"""
|
| 742 |
+
# Validate volume exists
|
| 743 |
+
volume_state = _state_manager.get_volume(volume_id)
|
| 744 |
+
if volume_state is None:
|
| 745 |
+
raise HTTPException(
|
| 746 |
+
status_code=404,
|
| 747 |
+
detail=f"Volume not found: {volume_id}. Please upload volume first."
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# Validate file extension
|
| 751 |
+
filename = file.filename or ""
|
| 752 |
+
if not (filename.endswith(".nii") or filename.endswith(".nii.gz")):
|
| 753 |
+
raise HTTPException(
|
| 754 |
+
status_code=400,
|
| 755 |
+
detail="Invalid file format. Please upload a NIfTI file (.nii or .nii.gz)"
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
try:
|
| 759 |
+
# Save to temp file
|
| 760 |
+
suffix = ".nii.gz" if filename.endswith(".nii.gz") else ".nii"
|
| 761 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
| 762 |
+
content = await file.read()
|
| 763 |
+
tmp.write(content)
|
| 764 |
+
tmp_path = tmp.name
|
| 765 |
+
|
| 766 |
+
# Load with nibabel
|
| 767 |
+
import nibabel as nib
|
| 768 |
+
img = nib.load(tmp_path)
|
| 769 |
+
gt_array = np.asarray(img.dataobj).copy()
|
| 770 |
+
del img
|
| 771 |
+
|
| 772 |
+
# Clean up temp file
|
| 773 |
+
try:
|
| 774 |
+
Path(tmp_path).unlink(missing_ok=True)
|
| 775 |
+
except PermissionError:
|
| 776 |
+
logger.warning(f"Could not delete temp file: {tmp_path}")
|
| 777 |
+
|
| 778 |
+
# Handle 4D with trivial dimension
|
| 779 |
+
if gt_array.ndim == 4 and gt_array.shape[-1] == 1:
|
| 780 |
+
gt_array = gt_array[..., 0]
|
| 781 |
+
|
| 782 |
+
# Validate shape matches volume
|
| 783 |
+
if gt_array.shape != volume_state.array.shape:
|
| 784 |
+
raise HTTPException(
|
| 785 |
+
status_code=400,
|
| 786 |
+
detail=f"Shape mismatch: ground truth {gt_array.shape} vs volume {volume_state.array.shape}"
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
# Binarize (any non-zero is foreground)
|
| 790 |
+
gt_array = (gt_array > 0).astype(np.uint8)
|
| 791 |
+
|
| 792 |
+
# Store ground truth
|
| 793 |
+
_state_manager.update_ground_truth(volume_id, gt_array)
|
| 794 |
+
|
| 795 |
+
return GroundTruthUploadResponse(
|
| 796 |
+
volume_id=volume_id,
|
| 797 |
+
gt_shape=gt_array.shape,
|
| 798 |
+
)
|
| 799 |
+
|
| 800 |
+
except HTTPException:
|
| 801 |
+
raise
|
| 802 |
+
except Exception as e:
|
| 803 |
+
logger.error(f"Failed to load ground truth: {e}")
|
| 804 |
+
raise HTTPException(
|
| 805 |
+
status_code=400,
|
| 806 |
+
detail=f"Failed to load ground truth: {str(e)}"
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
# -------------------------------------------------------------------------
|
| 810 |
+
# Ground Truth Mask Data Endpoint (raw bytes for overlap viewer)
|
| 811 |
+
# -------------------------------------------------------------------------
|
| 812 |
+
|
| 813 |
+
@app.get("/ground-truth/{volume_id}/data")
|
| 814 |
+
async def get_ground_truth_data(volume_id: str):
|
| 815 |
+
"""Get raw ground truth mask data as Uint8 array.
|
| 816 |
+
|
| 817 |
+
Returns the ground truth data as a raw binary Uint8Array,
|
| 818 |
+
suitable for direct use in overlap visualization.
|
| 819 |
+
"""
|
| 820 |
+
volume_state = _state_manager.get_volume(volume_id)
|
| 821 |
+
if volume_state is None:
|
| 822 |
+
raise HTTPException(
|
| 823 |
+
status_code=404,
|
| 824 |
+
detail=f"Volume not found: {volume_id}"
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
if volume_state.ground_truth is None:
|
| 828 |
+
raise HTTPException(
|
| 829 |
+
status_code=404,
|
| 830 |
+
detail=f"No ground truth found for volume {volume_id}"
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
# Convert to uint8 and return as raw bytes
|
| 834 |
+
data = volume_state.ground_truth.astype(np.uint8)
|
| 835 |
+
|
| 836 |
+
return Response(
|
| 837 |
+
content=data.tobytes(),
|
| 838 |
+
media_type="application/octet-stream",
|
| 839 |
+
headers={
|
| 840 |
+
"Content-Disposition": f"attachment; filename=gt_{volume_id}.raw",
|
| 841 |
+
"X-GT-Shape": ",".join(map(str, volume_state.ground_truth.shape)),
|
| 842 |
+
}
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# -------------------------------------------------------------------------
|
| 846 |
+
# Metrics Endpoint (including Dice score)
|
| 847 |
+
# -------------------------------------------------------------------------
|
| 848 |
+
|
| 849 |
+
@app.get("/metrics/{volume_id}", response_model=MetricsResponse)
|
| 850 |
+
async def get_metrics(volume_id: str):
|
| 851 |
+
"""Get segmentation metrics including Dice score if ground truth is available.
|
| 852 |
+
|
| 853 |
+
Returns volume, voxel count, and overlap metrics (Dice, IoU, etc.)
|
| 854 |
+
if a ground truth mask has been uploaded.
|
| 855 |
+
"""
|
| 856 |
+
volume_state = _state_manager.get_volume(volume_id)
|
| 857 |
+
if volume_state is None:
|
| 858 |
+
raise HTTPException(
|
| 859 |
+
status_code=404,
|
| 860 |
+
detail=f"Volume not found: {volume_id}"
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
if volume_state.mask is None:
|
| 864 |
+
raise HTTPException(
|
| 865 |
+
status_code=400,
|
| 866 |
+
detail="No segmentation mask found. Run /segment first."
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# Basic metrics
|
| 870 |
+
from seg_app.metrics.segmentation_metrics import (
|
| 871 |
+
compute_segmentation_metrics,
|
| 872 |
+
compute_metrics_with_ground_truth,
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
basic_metrics = compute_segmentation_metrics(
|
| 876 |
+
volume_state.mask, volume_state.spacing
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
voxel_count = sum(basic_metrics["voxel_counts"].values())
|
| 880 |
+
|
| 881 |
+
response = MetricsResponse(
|
| 882 |
+
volume_id=volume_id,
|
| 883 |
+
has_ground_truth=volume_state.ground_truth is not None,
|
| 884 |
+
voxel_count=voxel_count,
|
| 885 |
+
volume_mm3=basic_metrics["total_volume_mm3"],
|
| 886 |
+
volume_ml=basic_metrics["total_volume_ml"],
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
# Add Dice metrics if ground truth is available
|
| 890 |
+
if volume_state.ground_truth is not None:
|
| 891 |
+
gt_metrics = compute_metrics_with_ground_truth(
|
| 892 |
+
volume_state.mask,
|
| 893 |
+
volume_state.ground_truth,
|
| 894 |
+
volume_state.spacing,
|
| 895 |
+
)
|
| 896 |
+
response.dice_score = gt_metrics["dice_score"]
|
| 897 |
+
response.iou_score = gt_metrics["iou_score"]
|
| 898 |
+
response.sensitivity = gt_metrics["sensitivity"]
|
| 899 |
+
response.precision = gt_metrics["precision"]
|
| 900 |
+
response.true_positives = gt_metrics["true_positives"]
|
| 901 |
+
response.false_positives = gt_metrics["false_positives"]
|
| 902 |
+
response.false_negatives = gt_metrics["false_negatives"]
|
| 903 |
+
|
| 904 |
+
return response
|
| 905 |
+
|
| 906 |
return app
|
| 907 |
|
| 908 |
|
|
|
|
| 915 |
uvicorn.run(app, host="127.0.0.1", port=8000)
|
| 916 |
|
| 917 |
|
| 918 |
+
# Terminal 1 - Start Backend FIRST:
|
| 919 |
+
# cd e:/academia/research_projects/web_app
|
| 920 |
+
# conda activate seg_app
|
| 921 |
+
# uvicorn seg_app.backend.api:app --reload --port 8000
|
| 922 |
+
#
|
| 923 |
+
# Terminal 2 - Start Frontend SECOND:
|
| 924 |
+
# cd e:/academia/research_projects/web_app/seg_app/ui_slicer
|
| 925 |
+
# python serve_frontend.py
|
|
|
|
|
|
seg_app/config/settings.py
CHANGED
|
@@ -76,6 +76,47 @@ class HFHubConfig:
|
|
| 76 |
HF_HUB_CONFIG = HFHubConfig()
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# =============================================================================
|
| 80 |
# Application Settings
|
| 81 |
# =============================================================================
|
|
|
|
| 76 |
HF_HUB_CONFIG = HFHubConfig()
|
| 77 |
|
| 78 |
|
| 79 |
+
# =============================================================================
|
| 80 |
+
# nnUNet Configuration
|
| 81 |
+
# =============================================================================
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class nnUNetConfig:
|
| 85 |
+
"""Configuration for nnUNet models.
|
| 86 |
+
|
| 87 |
+
Set checkpoint_path to enable nnUNet in the model dropdown.
|
| 88 |
+
|
| 89 |
+
Attributes:
|
| 90 |
+
checkpoint_path: Path to the nnUNet training output folder.
|
| 91 |
+
This should point to the trainer directory, e.g.:
|
| 92 |
+
/path/to/nnUNet_results/DatasetXXX_Name/nnUNetTrainer__nnUNetPlans__3d_fullres
|
| 93 |
+
use_folds: Which folds to use for inference.
|
| 94 |
+
Options: "all" (default), or list of integers [0], [0, 1], etc.
|
| 95 |
+
use_mirroring: Whether to use test-time augmentation with mirroring.
|
| 96 |
+
Improves accuracy but increases inference time.
|
| 97 |
+
display_name: Name shown in the UI model dropdown.
|
| 98 |
+
"""
|
| 99 |
+
# Path to nnUNet checkpoint folder
|
| 100 |
+
# Set this to your local nnUNet_results path to enable nnUNet
|
| 101 |
+
# Example: "/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres"
|
| 102 |
+
checkpoint_path: Optional[str] = None
|
| 103 |
+
|
| 104 |
+
# Inference settings
|
| 105 |
+
use_folds: str = "all" # "all" or specific fold like [0]
|
| 106 |
+
use_mirroring: bool = True # Test-time augmentation (slower but more accurate)
|
| 107 |
+
|
| 108 |
+
# Display settings
|
| 109 |
+
display_name: str = "nnU-Net (Brain Lesion)"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Global nnUNet configuration instance
|
| 113 |
+
# Users should modify this path to point to their trained nnUNet model
|
| 114 |
+
NNUNET_CONFIG = nnUNetConfig(
|
| 115 |
+
# Uncomment and set this path to enable nnUNet:
|
| 116 |
+
# checkpoint_path="/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset###_Name/nnUNetTrainer__nnUNetPlans__3d_fullres",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
# =============================================================================
|
| 121 |
# Application Settings
|
| 122 |
# =============================================================================
|
seg_app/inference/model_registry.py
CHANGED
|
@@ -232,4 +232,12 @@ def _ensure_models_registered() -> None:
|
|
| 232 |
from seg_app.models.medical_sam_3d import register_medical_sam_3d
|
| 233 |
register_medical_sam_3d()
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
_models_registered = True
|
|
|
|
| 232 |
from seg_app.models.medical_sam_3d import register_medical_sam_3d
|
| 233 |
register_medical_sam_3d()
|
| 234 |
|
| 235 |
+
# Register nnUNet if configured (local checkpoint required)
|
| 236 |
+
try:
|
| 237 |
+
from seg_app.models.nnunet_wrapper import register_nnunet
|
| 238 |
+
register_nnunet()
|
| 239 |
+
except Exception as e:
|
| 240 |
+
import logging
|
| 241 |
+
logging.getLogger(__name__).debug(f"nnUNet registration skipped: {e}")
|
| 242 |
+
|
| 243 |
_models_registered = True
|
seg_app/inference/orchestrator.py
CHANGED
|
@@ -190,12 +190,24 @@ def get_available_models() -> List[Dict[str, str]]:
|
|
| 190 |
"""Get list of available models for brain lesion segmentation.
|
| 191 |
|
| 192 |
Returns:
|
| 193 |
-
List of dicts with 'id' and 'display_name' for each model
|
|
|
|
| 194 |
"""
|
| 195 |
-
|
|
|
|
|
|
|
| 196 |
{"id": "medical-sam-3d", "display_name": "Medical SAM 3D (SA-Med3D-140K)"},
|
| 197 |
{"id": "unet3d-brain-tumor", "display_name": "3D U-Net (Baseline)"},
|
| 198 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
def get_task_info(task_name: str) -> Dict[str, Any]:
|
|
|
|
| 190 |
"""Get list of available models for brain lesion segmentation.
|
| 191 |
|
| 192 |
Returns:
|
| 193 |
+
List of dicts with 'id' and 'display_name' for each model.
|
| 194 |
+
nnUNet is included only if configured in settings.
|
| 195 |
"""
|
| 196 |
+
from seg_app.config.settings import NNUNET_CONFIG
|
| 197 |
+
|
| 198 |
+
models = [
|
| 199 |
{"id": "medical-sam-3d", "display_name": "Medical SAM 3D (SA-Med3D-140K)"},
|
| 200 |
{"id": "unet3d-brain-tumor", "display_name": "3D U-Net (Baseline)"},
|
| 201 |
]
|
| 202 |
+
|
| 203 |
+
# Add nnUNet if checkpoint path is configured
|
| 204 |
+
if NNUNET_CONFIG.checkpoint_path is not None:
|
| 205 |
+
models.insert(0, {
|
| 206 |
+
"id": "nnunet-brain-lesion",
|
| 207 |
+
"display_name": NNUNET_CONFIG.display_name,
|
| 208 |
+
})
|
| 209 |
+
|
| 210 |
+
return models
|
| 211 |
|
| 212 |
|
| 213 |
def get_task_info(task_name: str) -> Dict[str, Any]:
|
seg_app/metrics/segmentation_metrics.py
CHANGED
|
@@ -102,6 +102,144 @@ def compute_segmentation_metrics(
|
|
| 102 |
}
|
| 103 |
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def format_metrics_for_display(metrics: Dict[str, Any]) -> Dict[str, Any]:
|
| 106 |
"""Format metrics dictionary for UI display.
|
| 107 |
|
|
|
|
| 102 |
}
|
| 103 |
|
| 104 |
|
| 105 |
+
def compute_dice_score(
|
| 106 |
+
prediction: np.ndarray,
|
| 107 |
+
ground_truth: np.ndarray,
|
| 108 |
+
smooth: float = 1e-6,
|
| 109 |
+
) -> float:
|
| 110 |
+
"""Compute the Dice similarity coefficient between prediction and ground truth.
|
| 111 |
+
|
| 112 |
+
Dice = 2 * |P ∩ G| / (|P| + |G|)
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
prediction: Predicted binary segmentation mask (D, H, W)
|
| 116 |
+
ground_truth: Ground truth binary segmentation mask (D, H, W)
|
| 117 |
+
smooth: Smoothing factor to avoid division by zero
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Dice score in range [0, 1], where 1 is perfect overlap
|
| 121 |
+
|
| 122 |
+
Raises:
|
| 123 |
+
ValueError: If masks have different shapes
|
| 124 |
+
"""
|
| 125 |
+
if prediction.shape != ground_truth.shape:
|
| 126 |
+
raise ValueError(
|
| 127 |
+
f"Shape mismatch: prediction {prediction.shape} vs ground_truth {ground_truth.shape}"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Binarize masks
|
| 131 |
+
pred_binary = (prediction > 0).astype(np.float32)
|
| 132 |
+
gt_binary = (ground_truth > 0).astype(np.float32)
|
| 133 |
+
|
| 134 |
+
# Compute intersection and unions
|
| 135 |
+
intersection = np.sum(pred_binary * gt_binary)
|
| 136 |
+
pred_sum = np.sum(pred_binary)
|
| 137 |
+
gt_sum = np.sum(gt_binary)
|
| 138 |
+
|
| 139 |
+
# Dice formula with smoothing
|
| 140 |
+
dice = (2.0 * intersection + smooth) / (pred_sum + gt_sum + smooth)
|
| 141 |
+
|
| 142 |
+
return float(dice)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def compute_iou_score(
|
| 146 |
+
prediction: np.ndarray,
|
| 147 |
+
ground_truth: np.ndarray,
|
| 148 |
+
smooth: float = 1e-6,
|
| 149 |
+
) -> float:
|
| 150 |
+
"""Compute the Intersection over Union (Jaccard index) between masks.
|
| 151 |
+
|
| 152 |
+
IoU = |P ∩ G| / |P ∪ G|
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
prediction: Predicted binary segmentation mask (D, H, W)
|
| 156 |
+
ground_truth: Ground truth binary segmentation mask (D, H, W)
|
| 157 |
+
smooth: Smoothing factor to avoid division by zero
|
| 158 |
+
|
| 159 |
+
Returns:
|
| 160 |
+
IoU score in range [0, 1], where 1 is perfect overlap
|
| 161 |
+
"""
|
| 162 |
+
if prediction.shape != ground_truth.shape:
|
| 163 |
+
raise ValueError(
|
| 164 |
+
f"Shape mismatch: prediction {prediction.shape} vs ground_truth {ground_truth.shape}"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
pred_binary = (prediction > 0).astype(np.float32)
|
| 168 |
+
gt_binary = (ground_truth > 0).astype(np.float32)
|
| 169 |
+
|
| 170 |
+
intersection = np.sum(pred_binary * gt_binary)
|
| 171 |
+
union = np.sum(pred_binary) + np.sum(gt_binary) - intersection
|
| 172 |
+
|
| 173 |
+
iou = (intersection + smooth) / (union + smooth)
|
| 174 |
+
|
| 175 |
+
return float(iou)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def compute_metrics_with_ground_truth(
|
| 179 |
+
prediction: np.ndarray,
|
| 180 |
+
ground_truth: np.ndarray,
|
| 181 |
+
spacing: Tuple[float, float, float],
|
| 182 |
+
) -> Dict[str, Any]:
|
| 183 |
+
"""Compute comprehensive metrics comparing prediction to ground truth.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
prediction: Predicted segmentation mask (D, H, W)
|
| 187 |
+
ground_truth: Ground truth segmentation mask (D, H, W)
|
| 188 |
+
spacing: Voxel spacing in mm as (depth, height, width)
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Dict containing all metrics including Dice, IoU, volumes, etc.
|
| 192 |
+
"""
|
| 193 |
+
# Basic metrics for prediction
|
| 194 |
+
pred_metrics = compute_segmentation_metrics(prediction, spacing)
|
| 195 |
+
gt_metrics = compute_segmentation_metrics(ground_truth, spacing)
|
| 196 |
+
|
| 197 |
+
# Overlap metrics
|
| 198 |
+
dice = compute_dice_score(prediction, ground_truth)
|
| 199 |
+
iou = compute_iou_score(prediction, ground_truth)
|
| 200 |
+
|
| 201 |
+
# Additional statistics
|
| 202 |
+
pred_binary = prediction > 0
|
| 203 |
+
gt_binary = ground_truth > 0
|
| 204 |
+
|
| 205 |
+
tp = np.sum(pred_binary & gt_binary) # True positives
|
| 206 |
+
fp = np.sum(pred_binary & ~gt_binary) # False positives
|
| 207 |
+
fn = np.sum(~pred_binary & gt_binary) # False negatives
|
| 208 |
+
tn = np.sum(~pred_binary & ~gt_binary) # True negatives
|
| 209 |
+
|
| 210 |
+
# Derived metrics
|
| 211 |
+
sensitivity = tp / (tp + fn + 1e-6) # Recall / TPR
|
| 212 |
+
specificity = tn / (tn + fp + 1e-6) # TNR
|
| 213 |
+
precision = tp / (tp + fp + 1e-6) # PPV
|
| 214 |
+
|
| 215 |
+
return {
|
| 216 |
+
# Overlap metrics
|
| 217 |
+
"dice_score": round(dice, 4),
|
| 218 |
+
"iou_score": round(iou, 4),
|
| 219 |
+
|
| 220 |
+
# Classification metrics
|
| 221 |
+
"sensitivity": round(sensitivity, 4),
|
| 222 |
+
"specificity": round(specificity, 4),
|
| 223 |
+
"precision": round(precision, 4),
|
| 224 |
+
|
| 225 |
+
# Voxel counts
|
| 226 |
+
"true_positives": int(tp),
|
| 227 |
+
"false_positives": int(fp),
|
| 228 |
+
"false_negatives": int(fn),
|
| 229 |
+
|
| 230 |
+
# Volume metrics
|
| 231 |
+
"prediction_volume_mm3": pred_metrics["total_volume_mm3"],
|
| 232 |
+
"ground_truth_volume_mm3": gt_metrics["total_volume_mm3"],
|
| 233 |
+
"volume_difference_mm3": round(
|
| 234 |
+
pred_metrics["total_volume_mm3"] - gt_metrics["total_volume_mm3"], 2
|
| 235 |
+
),
|
| 236 |
+
|
| 237 |
+
# Metadata
|
| 238 |
+
"mask_shape": prediction.shape,
|
| 239 |
+
"spacing_mm": spacing,
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
|
| 243 |
def format_metrics_for_display(metrics: Dict[str, Any]) -> Dict[str, Any]:
|
| 244 |
"""Format metrics dictionary for UI display.
|
| 245 |
|
seg_app/models/monai_autoseg.py
CHANGED
|
@@ -344,6 +344,25 @@ def register_monai_models() -> None:
|
|
| 344 |
),
|
| 345 |
)
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
# Additional models can be registered here
|
| 348 |
# register_model(
|
| 349 |
# model_id="monai-auto3dseg-spleen",
|
|
|
|
| 344 |
),
|
| 345 |
)
|
| 346 |
|
| 347 |
+
# 3D U-Net for brain lesion segmentation (baseline model)
|
| 348 |
+
register_model(
|
| 349 |
+
model_id="unet3d-brain-tumor",
|
| 350 |
+
model_class=MONAIAuto3DSeg,
|
| 351 |
+
config=ModelConfig(
|
| 352 |
+
model_id="unet3d-brain-tumor",
|
| 353 |
+
hf_hub_path="your-org/brain-tumor-model", # Placeholder HF path
|
| 354 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 355 |
+
preprocessing={
|
| 356 |
+
"target_spacing": (1.0, 1.0, 1.0),
|
| 357 |
+
"normalize": True,
|
| 358 |
+
},
|
| 359 |
+
postprocessing={
|
| 360 |
+
"threshold": 0.5,
|
| 361 |
+
"min_component_size": 50,
|
| 362 |
+
},
|
| 363 |
+
),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
# Additional models can be registered here
|
| 367 |
# register_model(
|
| 368 |
# model_id="monai-auto3dseg-spleen",
|
seg_app/models/nnunet_wrapper.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nnUNet model wrapper for brain lesion segmentation.
|
| 3 |
+
|
| 4 |
+
Provides a standardized BaseModel interface for nnUNet v2 models,
|
| 5 |
+
supporting loading from local checkpoint directories.
|
| 6 |
+
|
| 7 |
+
nnUNet v2 Reference:
|
| 8 |
+
https://github.com/MIC-DKFZ/nnUNet
|
| 9 |
+
|
| 10 |
+
Expected directory structure:
|
| 11 |
+
nnUNet_results/
|
| 12 |
+
└── DatasetXXX_Name/
|
| 13 |
+
└── nnUNetTrainer__nnUNetPlans__3d_fullres/
|
| 14 |
+
├── plans.json
|
| 15 |
+
├── dataset.json
|
| 16 |
+
├── fold_0/
|
| 17 |
+
│ └── checkpoint_final.pth
|
| 18 |
+
├── fold_1/
|
| 19 |
+
│ └── checkpoint_final.pth
|
| 20 |
+
└── ...
|
| 21 |
+
|
| 22 |
+
Usage:
|
| 23 |
+
config = ModelConfig(
|
| 24 |
+
model_id="nnunet-brain-lesion",
|
| 25 |
+
local_path="/path/to/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres",
|
| 26 |
+
)
|
| 27 |
+
model = nnUNetWrapper(config)
|
| 28 |
+
model.load()
|
| 29 |
+
mask = model(volume, spacing=(1.0, 1.0, 1.0))
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
import logging
|
| 33 |
+
import os
|
| 34 |
+
import tempfile
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
import torch
|
| 40 |
+
|
| 41 |
+
from seg_app.models.base import BaseModel, ModelConfig
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class nnUNetWrapper(BaseModel):
|
| 47 |
+
"""nnUNet v2 model wrapper implementing the BaseModel interface.
|
| 48 |
+
|
| 49 |
+
This wrapper integrates nnUNet v2's inference API with the seg_app
|
| 50 |
+
orchestrator pattern. It supports:
|
| 51 |
+
|
| 52 |
+
- Loading from local checkpoint directory
|
| 53 |
+
- Single-fold or ensemble inference
|
| 54 |
+
- Automatic preprocessing (handled by nnUNet predictor)
|
| 55 |
+
- Configurable postprocessing
|
| 56 |
+
|
| 57 |
+
Note:
|
| 58 |
+
nnUNet handles its own preprocessing (resampling, normalization)
|
| 59 |
+
based on the plans.json, so preprocess() passes through the volume
|
| 60 |
+
and postprocess() handles the output conversion.
|
| 61 |
+
|
| 62 |
+
Attributes:
|
| 63 |
+
predictor: The nnUNet predictor instance (lazy-loaded)
|
| 64 |
+
use_folds: Which folds to use for inference (default: all available)
|
| 65 |
+
use_gaussian: Whether to use Gaussian weighting for patch aggregation
|
| 66 |
+
use_mirroring: Whether to use test-time augmentation with mirroring
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(self, config: ModelConfig):
|
| 70 |
+
"""Initialize nnUNet wrapper.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
config: Model configuration with local_path pointing to the
|
| 74 |
+
trained nnUNet model folder (e.g., nnUNetTrainer__nnUNetPlans__3d_fullres)
|
| 75 |
+
"""
|
| 76 |
+
super().__init__(config)
|
| 77 |
+
|
| 78 |
+
self.predictor = None
|
| 79 |
+
|
| 80 |
+
# Inference settings (can be overridden via config.preprocessing)
|
| 81 |
+
preprocessing = config.preprocessing or {}
|
| 82 |
+
self.use_folds: Union[str, List[int]] = preprocessing.get("use_folds", "all")
|
| 83 |
+
self.use_gaussian: bool = preprocessing.get("use_gaussian", True)
|
| 84 |
+
self.use_mirroring: bool = preprocessing.get("use_mirroring", True)
|
| 85 |
+
self.save_probabilities: bool = preprocessing.get("save_probabilities", False)
|
| 86 |
+
|
| 87 |
+
# Postprocessing settings
|
| 88 |
+
postprocessing = config.postprocessing or {}
|
| 89 |
+
self.threshold: float = postprocessing.get("threshold", 0.5)
|
| 90 |
+
|
| 91 |
+
# Cache for input metadata (needed if nnUNet changes spacing)
|
| 92 |
+
self._original_spacing: Optional[Tuple[float, float, float]] = None
|
| 93 |
+
self._original_shape: Optional[Tuple[int, int, int]] = None
|
| 94 |
+
|
| 95 |
+
def load(self) -> None:
|
| 96 |
+
"""Load nnUNet model from local checkpoint directory.
|
| 97 |
+
|
| 98 |
+
This initializes the nnUNetPredictor and loads model weights.
|
| 99 |
+
The predictor handles all internal setup (network architecture,
|
| 100 |
+
preprocessing pipeline, etc.) based on plans.json.
|
| 101 |
+
|
| 102 |
+
Raises:
|
| 103 |
+
FileNotFoundError: If local_path doesn't exist
|
| 104 |
+
ImportError: If nnunetv2 is not installed
|
| 105 |
+
RuntimeError: If model loading fails
|
| 106 |
+
"""
|
| 107 |
+
if self._is_loaded:
|
| 108 |
+
logger.info(f"Model {self.config.model_id} already loaded")
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
if self.config.local_path is None:
|
| 112 |
+
raise RuntimeError(
|
| 113 |
+
f"nnUNet requires a local path to the trained model folder. "
|
| 114 |
+
f"Set local_path in ModelConfig to the nnUNetTrainer__nnUNetPlans__3d_fullres directory."
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
model_folder = Path(self.config.local_path)
|
| 118 |
+
if not model_folder.exists():
|
| 119 |
+
raise FileNotFoundError(
|
| 120 |
+
f"nnUNet model folder not found: {model_folder}\n"
|
| 121 |
+
f"Expected structure: nnUNet_results/DatasetXXX_Name/nnUNetTrainer__nnUNetPlans__3d_fullres/"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Check for required files
|
| 125 |
+
plans_file = model_folder / "plans.json"
|
| 126 |
+
if not plans_file.exists():
|
| 127 |
+
# Try alternative location
|
| 128 |
+
plans_file = model_folder / "nnUNetPlans.json"
|
| 129 |
+
if not plans_file.exists():
|
| 130 |
+
raise FileNotFoundError(
|
| 131 |
+
f"plans.json not found in {model_folder}. "
|
| 132 |
+
f"Ensure this is a valid nnUNet training output folder."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
|
| 137 |
+
except ImportError:
|
| 138 |
+
raise ImportError(
|
| 139 |
+
"nnunetv2 is not installed. Install it with:\n"
|
| 140 |
+
"pip install nnunetv2\n"
|
| 141 |
+
"Or see: https://github.com/MIC-DKFZ/nnUNet"
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
logger.info(f"Loading nnUNet model from: {model_folder}")
|
| 145 |
+
|
| 146 |
+
# Initialize predictor
|
| 147 |
+
self.predictor = nnUNetPredictor(
|
| 148 |
+
tile_step_size=0.5,
|
| 149 |
+
use_gaussian=self.use_gaussian,
|
| 150 |
+
use_mirroring=self.use_mirroring,
|
| 151 |
+
perform_everything_on_device=True,
|
| 152 |
+
device=torch.device(self.device),
|
| 153 |
+
verbose=False,
|
| 154 |
+
verbose_preprocessing=False,
|
| 155 |
+
allow_tqdm=False,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Determine which folds to use
|
| 159 |
+
if self.use_folds == "all":
|
| 160 |
+
folds = self._find_available_folds(model_folder)
|
| 161 |
+
else:
|
| 162 |
+
folds = self.use_folds
|
| 163 |
+
|
| 164 |
+
if not folds:
|
| 165 |
+
# Default to fold 0 if no specific folds found
|
| 166 |
+
folds = [0]
|
| 167 |
+
logger.warning(f"No fold directories found, defaulting to fold 0")
|
| 168 |
+
|
| 169 |
+
logger.info(f"Using folds: {folds}")
|
| 170 |
+
|
| 171 |
+
# Initialize from trained model folder
|
| 172 |
+
# This loads plans.json, network architecture, and weights
|
| 173 |
+
self.predictor.initialize_from_trained_model_folder(
|
| 174 |
+
model_training_output_dir=str(model_folder),
|
| 175 |
+
use_folds=folds,
|
| 176 |
+
checkpoint_name="checkpoint_final.pth", # or "checkpoint_best.pth"
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
self._is_loaded = True
|
| 180 |
+
logger.info(f"nnUNet model loaded successfully on {self.device}")
|
| 181 |
+
|
| 182 |
+
def _find_available_folds(self, model_folder: Path) -> List[int]:
|
| 183 |
+
"""Find available fold directories in the model folder.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
model_folder: Path to nnUNet training output
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
List of fold numbers (e.g., [0, 1, 2, 3, 4])
|
| 190 |
+
"""
|
| 191 |
+
folds = []
|
| 192 |
+
for item in model_folder.iterdir():
|
| 193 |
+
if item.is_dir() and item.name.startswith("fold_"):
|
| 194 |
+
try:
|
| 195 |
+
fold_num = int(item.name.split("_")[1])
|
| 196 |
+
# Check if checkpoint exists
|
| 197 |
+
if (item / "checkpoint_final.pth").exists():
|
| 198 |
+
folds.append(fold_num)
|
| 199 |
+
elif (item / "checkpoint_best.pth").exists():
|
| 200 |
+
folds.append(fold_num)
|
| 201 |
+
except (ValueError, IndexError):
|
| 202 |
+
continue
|
| 203 |
+
return sorted(folds)
|
| 204 |
+
|
| 205 |
+
def preprocess(
|
| 206 |
+
self,
|
| 207 |
+
volume: np.ndarray,
|
| 208 |
+
spacing: Optional[Tuple[float, float, float]] = None,
|
| 209 |
+
) -> torch.Tensor:
|
| 210 |
+
"""Prepare volume for nnUNet inference.
|
| 211 |
+
|
| 212 |
+
Note: nnUNet's predictor handles its own preprocessing internally,
|
| 213 |
+
so this method mainly stores metadata and prepares the volume format.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
volume: 3D numpy array with shape (D, H, W)
|
| 217 |
+
spacing: Voxel spacing in mm as (depth, height, width)
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
Volume as tensor (nnUNet expects numpy, but we follow interface)
|
| 221 |
+
"""
|
| 222 |
+
# Store original metadata for postprocessing
|
| 223 |
+
self._original_shape = volume.shape
|
| 224 |
+
self._original_spacing = spacing
|
| 225 |
+
|
| 226 |
+
# nnUNet expects (C, D, H, W) where C is modality/channel
|
| 227 |
+
# For single-modality MRI, add channel dimension
|
| 228 |
+
if volume.ndim == 3:
|
| 229 |
+
volume = volume[np.newaxis, ...] # (1, D, H, W)
|
| 230 |
+
|
| 231 |
+
# Return as tensor to match interface (but nnUNet will work with numpy)
|
| 232 |
+
tensor = torch.from_numpy(volume.astype(np.float32))
|
| 233 |
+
return tensor.to(self.device)
|
| 234 |
+
|
| 235 |
+
def predict(
|
| 236 |
+
self,
|
| 237 |
+
tensor: torch.Tensor,
|
| 238 |
+
prompts: Optional[Any] = None,
|
| 239 |
+
) -> torch.Tensor:
|
| 240 |
+
"""Run nnUNet inference.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
tensor: Preprocessed volume tensor with shape (1, D, H, W)
|
| 244 |
+
prompts: Not used for nnUNet (ignored with warning)
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Segmentation output tensor
|
| 248 |
+
|
| 249 |
+
Raises:
|
| 250 |
+
RuntimeError: If model not loaded
|
| 251 |
+
"""
|
| 252 |
+
if not self._is_loaded or self.predictor is None:
|
| 253 |
+
raise RuntimeError("Model not loaded. Call load() first.")
|
| 254 |
+
|
| 255 |
+
if prompts is not None:
|
| 256 |
+
logger.warning(
|
| 257 |
+
"nnUNet does not support interactive prompts. Ignoring provided prompts. "
|
| 258 |
+
"Use Medical SAM 3D for prompt-based refinement."
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Convert back to numpy for nnUNet (it works with numpy internally)
|
| 262 |
+
volume_np = tensor.cpu().numpy() # (1, D, H, W)
|
| 263 |
+
|
| 264 |
+
# Prepare properties dict for nnUNet
|
| 265 |
+
# nnUNet expects spacing in (x, y, z) = (W, H, D) order
|
| 266 |
+
if self._original_spacing is not None:
|
| 267 |
+
# Input spacing is (D, H, W), nnUNet wants (W, H, D)
|
| 268 |
+
spacing_xyz = self._original_spacing[::-1] # Reverse order
|
| 269 |
+
else:
|
| 270 |
+
# Default to 1mm isotropic if not provided
|
| 271 |
+
spacing_xyz = (1.0, 1.0, 1.0)
|
| 272 |
+
logger.warning("No spacing provided, using 1mm isotropic")
|
| 273 |
+
|
| 274 |
+
properties = {
|
| 275 |
+
"spacing": list(spacing_xyz),
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
# Run prediction using nnUNet's predict_single_npy_array
|
| 279 |
+
# This handles all internal preprocessing, sliding window, and postprocessing
|
| 280 |
+
with torch.inference_mode():
|
| 281 |
+
segmentation = self.predictor.predict_single_npy_array(
|
| 282 |
+
input_image=volume_np,
|
| 283 |
+
image_properties=properties,
|
| 284 |
+
segmentation_previous_stage=None,
|
| 285 |
+
output_file_truncated=None,
|
| 286 |
+
save_or_return_probabilities=self.save_probabilities,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Convert result to tensor
|
| 290 |
+
if isinstance(segmentation, tuple):
|
| 291 |
+
# If save_probabilities=True, returns (segmentation, probabilities)
|
| 292 |
+
segmentation = segmentation[0]
|
| 293 |
+
|
| 294 |
+
return torch.from_numpy(segmentation.astype(np.float32))
|
| 295 |
+
|
| 296 |
+
def postprocess(
|
| 297 |
+
self,
|
| 298 |
+
tensor: torch.Tensor,
|
| 299 |
+
original_shape: Tuple[int, int, int],
|
| 300 |
+
) -> np.ndarray:
|
| 301 |
+
"""Postprocess nnUNet output.
|
| 302 |
+
|
| 303 |
+
nnUNet typically outputs integer labels, so minimal postprocessing
|
| 304 |
+
is needed. This method ensures the output is a binary mask and
|
| 305 |
+
optionally resamples to the original shape if nnUNet changed it.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
tensor: Model output tensor
|
| 309 |
+
original_shape: Original volume shape (D, H, W)
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
Binary segmentation mask as numpy array
|
| 313 |
+
"""
|
| 314 |
+
mask = tensor.cpu().numpy()
|
| 315 |
+
|
| 316 |
+
# nnUNet outputs integer labels (0 = background, 1+ = foreground classes)
|
| 317 |
+
# For binary segmentation, threshold at 0.5 if probabilities,
|
| 318 |
+
# or take any non-zero value as foreground
|
| 319 |
+
if mask.dtype == np.float32 or mask.dtype == np.float64:
|
| 320 |
+
# Probability output
|
| 321 |
+
mask = (mask > self.threshold).astype(np.uint8)
|
| 322 |
+
else:
|
| 323 |
+
# Integer label output - binarize (any label > 0 is foreground)
|
| 324 |
+
mask = (mask > 0).astype(np.uint8)
|
| 325 |
+
|
| 326 |
+
# Ensure output matches original shape
|
| 327 |
+
if mask.shape != original_shape:
|
| 328 |
+
# nnUNet may change shape due to resampling
|
| 329 |
+
# Use scipy to resample back
|
| 330 |
+
from scipy.ndimage import zoom
|
| 331 |
+
|
| 332 |
+
zoom_factors = tuple(
|
| 333 |
+
o / m for o, m in zip(original_shape, mask.shape)
|
| 334 |
+
)
|
| 335 |
+
mask = zoom(mask.astype(np.float32), zoom_factors, order=0) > 0.5
|
| 336 |
+
mask = mask.astype(np.uint8)
|
| 337 |
+
|
| 338 |
+
return mask
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def register_nnunet() -> None:
|
| 342 |
+
"""Register nnUNet model in the model registry.
|
| 343 |
+
|
| 344 |
+
This function is called by the model registry during lazy initialization.
|
| 345 |
+
It registers the nnUNet wrapper with the configured checkpoint path.
|
| 346 |
+
"""
|
| 347 |
+
from seg_app.inference.model_registry import register_model
|
| 348 |
+
from seg_app.config.settings import NNUNET_CONFIG
|
| 349 |
+
|
| 350 |
+
# Only register if nnUNet checkpoint path is configured
|
| 351 |
+
if NNUNET_CONFIG.checkpoint_path is None:
|
| 352 |
+
logger.warning(
|
| 353 |
+
"nnUNet checkpoint path not configured. "
|
| 354 |
+
"Set NNUNET_CONFIG.checkpoint_path in settings.py to enable nnUNet."
|
| 355 |
+
)
|
| 356 |
+
return
|
| 357 |
+
|
| 358 |
+
config = ModelConfig(
|
| 359 |
+
model_id="nnunet-brain-lesion",
|
| 360 |
+
local_path=NNUNET_CONFIG.checkpoint_path,
|
| 361 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 362 |
+
preprocessing={
|
| 363 |
+
"use_folds": NNUNET_CONFIG.use_folds,
|
| 364 |
+
"use_gaussian": True,
|
| 365 |
+
"use_mirroring": NNUNET_CONFIG.use_mirroring,
|
| 366 |
+
},
|
| 367 |
+
postprocessing={
|
| 368 |
+
"threshold": 0.5,
|
| 369 |
+
},
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
register_model("nnunet-brain-lesion", nnUNetWrapper, config)
|
| 373 |
+
logger.info(f"Registered nnUNet model: nnunet-brain-lesion")
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Convenience function for direct testing
|
| 377 |
+
def create_nnunet_model(checkpoint_path: str, device: str = "cuda") -> nnUNetWrapper:
|
| 378 |
+
"""Create an nnUNet model instance for testing.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
checkpoint_path: Path to nnUNet training output folder
|
| 382 |
+
device: Device to load model on
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
Loaded nnUNet model ready for inference
|
| 386 |
+
|
| 387 |
+
Example:
|
| 388 |
+
>>> model = create_nnunet_model("/path/to/nnUNet_results/Dataset001/nnUNetTrainer__nnUNetPlans__3d_fullres")
|
| 389 |
+
>>> model.load()
|
| 390 |
+
>>> mask = model(volume, spacing=(1.0, 1.0, 1.0))
|
| 391 |
+
"""
|
| 392 |
+
config = ModelConfig(
|
| 393 |
+
model_id="nnunet-test",
|
| 394 |
+
local_path=checkpoint_path,
|
| 395 |
+
device=device,
|
| 396 |
+
)
|
| 397 |
+
model = nnUNetWrapper(config)
|
| 398 |
+
model.load()
|
| 399 |
+
return model
|
seg_app/ui_slicer/api-client.js
CHANGED
|
@@ -192,6 +192,64 @@ class ApiClient {
|
|
| 192 |
}
|
| 193 |
return await response.json();
|
| 194 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
}
|
| 196 |
|
| 197 |
// Singleton instance
|
|
|
|
| 192 |
}
|
| 193 |
return await response.json();
|
| 194 |
}
|
| 195 |
+
|
| 196 |
+
/**
|
| 197 |
+
* Upload ground truth mask for Dice score computation
|
| 198 |
+
* @param {string} volumeId - The volume ID
|
| 199 |
+
* @param {File} file - The ground truth NIfTI file
|
| 200 |
+
* @returns {Promise<{volume_id: string, gt_shape: number[], message: string}>}
|
| 201 |
+
*/
|
| 202 |
+
async uploadGroundTruth(volumeId, file) {
|
| 203 |
+
const formData = new FormData();
|
| 204 |
+
formData.append('file', file);
|
| 205 |
+
|
| 206 |
+
const response = await fetch(`${this.baseUrl}/ground-truth/${volumeId}`, {
|
| 207 |
+
method: 'POST',
|
| 208 |
+
body: formData
|
| 209 |
+
});
|
| 210 |
+
|
| 211 |
+
if (!response.ok) {
|
| 212 |
+
const error = await response.json();
|
| 213 |
+
throw new Error(error.detail || `Ground truth upload failed: ${response.status}`);
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
return await response.json();
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/**
|
| 220 |
+
* Get segmentation metrics including Dice score
|
| 221 |
+
* @param {string} volumeId - The volume ID
|
| 222 |
+
* @returns {Promise<{volume_id: string, has_ground_truth: boolean, voxel_count: number, volume_mm3: number, dice_score?: number, iou_score?: number, sensitivity?: number, precision?: number}>}
|
| 223 |
+
*/
|
| 224 |
+
async getMetrics(volumeId) {
|
| 225 |
+
const response = await fetch(`${this.baseUrl}/metrics/${volumeId}`);
|
| 226 |
+
|
| 227 |
+
if (!response.ok) {
|
| 228 |
+
const error = await response.json();
|
| 229 |
+
throw new Error(error.detail || `Failed to get metrics: ${response.status}`);
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
return await response.json();
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
/**
|
| 236 |
+
* Get ground truth mask data as Uint8Array (raw bytes)
|
| 237 |
+
* @param {string} volumeId - The volume ID
|
| 238 |
+
* @returns {Promise<ArrayBuffer|null>}
|
| 239 |
+
*/
|
| 240 |
+
async getGroundTruthData(volumeId) {
|
| 241 |
+
const response = await fetch(`${this.baseUrl}/ground-truth/${volumeId}/data`);
|
| 242 |
+
|
| 243 |
+
if (!response.ok) {
|
| 244 |
+
if (response.status === 404) {
|
| 245 |
+
return null; // No GT available
|
| 246 |
+
}
|
| 247 |
+
const error = await response.json();
|
| 248 |
+
throw new Error(error.detail || `Failed to get ground truth: ${response.status}`);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
return await response.arrayBuffer();
|
| 252 |
+
}
|
| 253 |
}
|
| 254 |
|
| 255 |
// Singleton instance
|
seg_app/ui_slicer/app.js
CHANGED
|
@@ -64,7 +64,11 @@ async function loadModels() {
|
|
| 64 |
const select = document.getElementById('model-select');
|
| 65 |
select.innerHTML = '';
|
| 66 |
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
const option = document.createElement('option');
|
| 69 |
// Handle both 'id' and 'model_id' from API
|
| 70 |
const modelId = model.id || model.model_id;
|
|
@@ -72,12 +76,20 @@ async function loadModels() {
|
|
| 72 |
option.textContent = model.display_name;
|
| 73 |
select.appendChild(option);
|
| 74 |
|
| 75 |
-
// Select
|
| 76 |
-
if (
|
| 77 |
appState.set('selectedModel', modelId);
|
|
|
|
|
|
|
| 78 |
}
|
| 79 |
});
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
select.disabled = false;
|
| 82 |
updateModelInfo();
|
| 83 |
} catch (error) {
|
|
@@ -104,6 +116,9 @@ async function handleVolumeUpload(file) {
|
|
| 104 |
showLoading('Uploading volume...');
|
| 105 |
|
| 106 |
try {
|
|
|
|
|
|
|
|
|
|
| 107 |
// Upload to backend
|
| 108 |
const response = await apiClient.uploadVolume(file);
|
| 109 |
|
|
@@ -194,11 +209,17 @@ async function runSegmentation() {
|
|
| 194 |
const shape = appState.get('volumeShape');
|
| 195 |
const spacing = appState.get('volumeSpacing');
|
| 196 |
mprViewer.loadMask(maskData, shape, spacing);
|
|
|
|
|
|
|
|
|
|
| 197 |
}
|
| 198 |
|
| 199 |
hideLoading();
|
| 200 |
showToast('Success', `Segmentation complete (${response.model_id})`, 'success');
|
| 201 |
|
|
|
|
|
|
|
|
|
|
| 202 |
} catch (error) {
|
| 203 |
hideLoading();
|
| 204 |
console.error('Segmentation failed:', error);
|
|
@@ -236,11 +257,17 @@ async function runRefinement() {
|
|
| 236 |
const shape = appState.get('volumeShape');
|
| 237 |
const spacing = appState.get('volumeSpacing');
|
| 238 |
mprViewer.loadMask(maskData, shape, spacing);
|
|
|
|
|
|
|
|
|
|
| 239 |
}
|
| 240 |
|
| 241 |
hideLoading();
|
| 242 |
showToast('Success', `Refinement complete (${response.model_id})`, 'success');
|
| 243 |
|
|
|
|
|
|
|
|
|
|
| 244 |
} catch (error) {
|
| 245 |
hideLoading();
|
| 246 |
console.error('Refinement failed:', error);
|
|
@@ -248,6 +275,324 @@ async function runRefinement() {
|
|
| 248 |
}
|
| 249 |
}
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
// ============ Tool Management ============
|
| 252 |
|
| 253 |
function enableTools(enabled) {
|
|
@@ -256,6 +601,10 @@ function enableTools(enabled) {
|
|
| 256 |
document.getElementById('tool-clear').disabled = !enabled;
|
| 257 |
document.getElementById('run-segment').disabled = !enabled;
|
| 258 |
document.getElementById('run-refine').disabled = !enabled;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
}
|
| 260 |
|
| 261 |
function setActiveTool(tool) {
|
|
@@ -420,6 +769,25 @@ function setupEventListeners() {
|
|
| 420 |
document.getElementById('run-segment').addEventListener('click', runSegmentation);
|
| 421 |
document.getElementById('run-refine').addEventListener('click', runRefinement);
|
| 422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
// Display sliders
|
| 424 |
document.getElementById('window-slider').addEventListener('input', updateWindowValue);
|
| 425 |
document.getElementById('level-slider').addEventListener('input', updateLevelValue);
|
|
@@ -447,6 +815,10 @@ async function initialize() {
|
|
| 447 |
mprViewer.initialize();
|
| 448 |
console.log('MPR viewer initialized');
|
| 449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
// Check backend connection
|
| 451 |
const connected = await checkConnection();
|
| 452 |
console.log('Backend connection:', connected);
|
|
|
|
| 64 |
const select = document.getElementById('model-select');
|
| 65 |
select.innerHTML = '';
|
| 66 |
|
| 67 |
+
// Find the 3D U-Net baseline model to set as default
|
| 68 |
+
const defaultModelId = 'unet3d-brain-tumor';
|
| 69 |
+
let defaultFound = false;
|
| 70 |
+
|
| 71 |
+
models.forEach((model) => {
|
| 72 |
const option = document.createElement('option');
|
| 73 |
// Handle both 'id' and 'model_id' from API
|
| 74 |
const modelId = model.id || model.model_id;
|
|
|
|
| 76 |
option.textContent = model.display_name;
|
| 77 |
select.appendChild(option);
|
| 78 |
|
| 79 |
+
// Select 3D U-Net (Baseline) as default
|
| 80 |
+
if (modelId === defaultModelId) {
|
| 81 |
appState.set('selectedModel', modelId);
|
| 82 |
+
select.value = modelId;
|
| 83 |
+
defaultFound = true;
|
| 84 |
}
|
| 85 |
});
|
| 86 |
|
| 87 |
+
// Fallback to first model if default not found
|
| 88 |
+
if (!defaultFound && models.length > 0) {
|
| 89 |
+
const firstModelId = models[0].id || models[0].model_id;
|
| 90 |
+
appState.set('selectedModel', firstModelId);
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
select.disabled = false;
|
| 94 |
updateModelInfo();
|
| 95 |
} catch (error) {
|
|
|
|
| 116 |
showLoading('Uploading volume...');
|
| 117 |
|
| 118 |
try {
|
| 119 |
+
// Reset previous state (GT, metrics, overlap viewer)
|
| 120 |
+
resetMaskAndMetrics();
|
| 121 |
+
|
| 122 |
// Upload to backend
|
| 123 |
const response = await apiClient.uploadVolume(file);
|
| 124 |
|
|
|
|
| 209 |
const shape = appState.get('volumeShape');
|
| 210 |
const spacing = appState.get('volumeSpacing');
|
| 211 |
mprViewer.loadMask(maskData, shape, spacing);
|
| 212 |
+
|
| 213 |
+
// Update overlap visualization
|
| 214 |
+
renderOverlapCanvas();
|
| 215 |
}
|
| 216 |
|
| 217 |
hideLoading();
|
| 218 |
showToast('Success', `Segmentation complete (${response.model_id})`, 'success');
|
| 219 |
|
| 220 |
+
// Fetch and display metrics
|
| 221 |
+
await updateMetricsDisplay();
|
| 222 |
+
|
| 223 |
} catch (error) {
|
| 224 |
hideLoading();
|
| 225 |
console.error('Segmentation failed:', error);
|
|
|
|
| 257 |
const shape = appState.get('volumeShape');
|
| 258 |
const spacing = appState.get('volumeSpacing');
|
| 259 |
mprViewer.loadMask(maskData, shape, spacing);
|
| 260 |
+
|
| 261 |
+
// Update overlap visualization
|
| 262 |
+
renderOverlapCanvas();
|
| 263 |
}
|
| 264 |
|
| 265 |
hideLoading();
|
| 266 |
showToast('Success', `Refinement complete (${response.model_id})`, 'success');
|
| 267 |
|
| 268 |
+
// Fetch and display metrics
|
| 269 |
+
await updateMetricsDisplay();
|
| 270 |
+
|
| 271 |
} catch (error) {
|
| 272 |
hideLoading();
|
| 273 |
console.error('Refinement failed:', error);
|
|
|
|
| 275 |
}
|
| 276 |
}
|
| 277 |
|
| 278 |
+
// ============ Ground Truth & Metrics ============
|
| 279 |
+
|
| 280 |
+
async function handleGroundTruthUpload(file) {
|
| 281 |
+
const volumeId = appState.get('volumeId');
|
| 282 |
+
if (!volumeId) {
|
| 283 |
+
showToast('Warning', 'Please upload a volume first', 'info');
|
| 284 |
+
return;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
showLoading('Uploading ground truth...');
|
| 288 |
+
|
| 289 |
+
try {
|
| 290 |
+
const response = await apiClient.uploadGroundTruth(volumeId, file);
|
| 291 |
+
|
| 292 |
+
// Update GT status
|
| 293 |
+
document.getElementById('gt-status').textContent = 'Loaded ✓';
|
| 294 |
+
document.getElementById('gt-status').style.color = 'var(--success-color)';
|
| 295 |
+
|
| 296 |
+
// Fetch the GT mask data for overlap visualization
|
| 297 |
+
const gtArrayBuffer = await apiClient.getGroundTruthData(volumeId);
|
| 298 |
+
const gtMaskData = new Uint8Array(gtArrayBuffer);
|
| 299 |
+
appState.setGroundTruth(gtMaskData);
|
| 300 |
+
|
| 301 |
+
// Render overlap canvas
|
| 302 |
+
renderOverlapCanvas();
|
| 303 |
+
|
| 304 |
+
hideLoading();
|
| 305 |
+
showToast('Success', `Ground truth loaded: ${response.gt_shape.join('×')}`, 'success');
|
| 306 |
+
|
| 307 |
+
// If we have a mask, update metrics to show Dice
|
| 308 |
+
if (appState.get('maskData')) {
|
| 309 |
+
await updateMetricsDisplay();
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
} catch (error) {
|
| 313 |
+
hideLoading();
|
| 314 |
+
console.error('Ground truth upload failed:', error);
|
| 315 |
+
showToast('Error', error.message, 'error');
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
async function updateMetricsDisplay() {
|
| 320 |
+
const volumeId = appState.get('volumeId');
|
| 321 |
+
if (!volumeId) return;
|
| 322 |
+
|
| 323 |
+
try {
|
| 324 |
+
const metrics = await apiClient.getMetrics(volumeId);
|
| 325 |
+
|
| 326 |
+
// Update volume display
|
| 327 |
+
const volumeEl = document.getElementById('mask-volume');
|
| 328 |
+
if (volumeEl) {
|
| 329 |
+
volumeEl.textContent = `${metrics.volume_ml.toFixed(2)} mL`;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
// Update Dice metrics if ground truth is available
|
| 333 |
+
if (metrics.has_ground_truth) {
|
| 334 |
+
document.getElementById('gt-status').textContent = 'Loaded ✓';
|
| 335 |
+
document.getElementById('gt-status').style.color = 'var(--success-color)';
|
| 336 |
+
|
| 337 |
+
// Dice score with color coding
|
| 338 |
+
const diceEl = document.getElementById('dice-score');
|
| 339 |
+
if (diceEl && metrics.dice_score !== null) {
|
| 340 |
+
const dice = metrics.dice_score;
|
| 341 |
+
diceEl.textContent = dice.toFixed(4);
|
| 342 |
+
// Color code: green > 0.7, yellow 0.5-0.7, red < 0.5
|
| 343 |
+
if (dice >= 0.7) {
|
| 344 |
+
diceEl.style.color = 'var(--success-color)';
|
| 345 |
+
} else if (dice >= 0.5) {
|
| 346 |
+
diceEl.style.color = 'var(--warning-color)';
|
| 347 |
+
} else {
|
| 348 |
+
diceEl.style.color = 'var(--error-color)';
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
const iouEl = document.getElementById('iou-score');
|
| 353 |
+
if (iouEl && metrics.iou_score !== null) {
|
| 354 |
+
iouEl.textContent = metrics.iou_score.toFixed(4);
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
const sensEl = document.getElementById('sensitivity');
|
| 358 |
+
if (sensEl && metrics.sensitivity !== null) {
|
| 359 |
+
sensEl.textContent = metrics.sensitivity.toFixed(4);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
const precEl = document.getElementById('precision-score');
|
| 363 |
+
if (precEl && metrics.precision !== null) {
|
| 364 |
+
precEl.textContent = metrics.precision.toFixed(4);
|
| 365 |
+
}
|
| 366 |
+
} else {
|
| 367 |
+
// No ground truth
|
| 368 |
+
document.getElementById('dice-score').textContent = '-';
|
| 369 |
+
document.getElementById('iou-score').textContent = '-';
|
| 370 |
+
document.getElementById('sensitivity').textContent = '-';
|
| 371 |
+
document.getElementById('precision-score').textContent = '-';
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
} catch (error) {
|
| 375 |
+
console.error('Failed to fetch metrics:', error);
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
/**
|
| 380 |
+
* Reset mask, ground truth, metrics display, and overlap viewer
|
| 381 |
+
* Called when uploading a new volume
|
| 382 |
+
*/
|
| 383 |
+
function resetMaskAndMetrics() {
|
| 384 |
+
// Clear state
|
| 385 |
+
appState.clearMask();
|
| 386 |
+
appState.clearGroundTruth();
|
| 387 |
+
|
| 388 |
+
// Reset GT status indicator
|
| 389 |
+
const gtStatus = document.getElementById('gt-status');
|
| 390 |
+
if (gtStatus) {
|
| 391 |
+
gtStatus.textContent = 'Not loaded';
|
| 392 |
+
gtStatus.style.color = '';
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
// Reset metrics display
|
| 396 |
+
const volumeEl = document.getElementById('mask-volume');
|
| 397 |
+
if (volumeEl) volumeEl.textContent = '-';
|
| 398 |
+
|
| 399 |
+
const diceEl = document.getElementById('dice-score');
|
| 400 |
+
if (diceEl) {
|
| 401 |
+
diceEl.textContent = '-';
|
| 402 |
+
diceEl.style.color = '';
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
const iouEl = document.getElementById('iou-score');
|
| 406 |
+
if (iouEl) iouEl.textContent = '-';
|
| 407 |
+
|
| 408 |
+
const sensEl = document.getElementById('sensitivity');
|
| 409 |
+
if (sensEl) sensEl.textContent = '-';
|
| 410 |
+
|
| 411 |
+
const precEl = document.getElementById('precision-score');
|
| 412 |
+
if (precEl) precEl.textContent = '-';
|
| 413 |
+
|
| 414 |
+
// Reset overlap viewer
|
| 415 |
+
const placeholder = document.getElementById('overlap-placeholder');
|
| 416 |
+
if (placeholder) placeholder.style.display = '';
|
| 417 |
+
|
| 418 |
+
const sliceIndicator = document.getElementById('overlap-slice-indicator');
|
| 419 |
+
if (sliceIndicator) sliceIndicator.style.display = 'none';
|
| 420 |
+
|
| 421 |
+
const sliceLabel = document.getElementById('overlap-slice-num');
|
| 422 |
+
if (sliceLabel) sliceLabel.textContent = '-';
|
| 423 |
+
|
| 424 |
+
// Clear overlap canvas
|
| 425 |
+
const canvas = document.getElementById('overlap-canvas');
|
| 426 |
+
if (canvas) {
|
| 427 |
+
const ctx = canvas.getContext('2d');
|
| 428 |
+
ctx.fillStyle = '#1a1a2e';
|
| 429 |
+
ctx.fillRect(0, 0, canvas.width, canvas.height);
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
// Clear mask from viewer
|
| 433 |
+
mprViewer.clearMask();
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
// ============ Overlap Visualization ============
|
| 437 |
+
|
| 438 |
+
/**
|
| 439 |
+
* Render overlap canvas showing TP/FP/FN comparison between prediction and ground truth
|
| 440 |
+
* Colors: TP=green, FP=red, FN=blue, Pred-only (no GT)=orange
|
| 441 |
+
*/
|
| 442 |
+
function renderOverlapCanvas() {
|
| 443 |
+
const canvas = document.getElementById('overlap-canvas');
|
| 444 |
+
const container = document.getElementById('overlap-canvas-container');
|
| 445 |
+
if (!canvas || !container) return;
|
| 446 |
+
|
| 447 |
+
const ctx = canvas.getContext('2d');
|
| 448 |
+
const maskData = appState.get('maskData');
|
| 449 |
+
const gtMaskData = appState.get('gtMaskData');
|
| 450 |
+
const volumeShape = appState.get('volumeShape');
|
| 451 |
+
|
| 452 |
+
// Get container size and set canvas to match
|
| 453 |
+
const rect = container.getBoundingClientRect();
|
| 454 |
+
const width = Math.floor(rect.width);
|
| 455 |
+
const height = Math.floor(rect.height);
|
| 456 |
+
|
| 457 |
+
if (width === 0 || height === 0) return;
|
| 458 |
+
|
| 459 |
+
canvas.width = width;
|
| 460 |
+
canvas.height = height;
|
| 461 |
+
|
| 462 |
+
// Clear canvas
|
| 463 |
+
ctx.fillStyle = '#1a1a2e';
|
| 464 |
+
ctx.fillRect(0, 0, width, height);
|
| 465 |
+
|
| 466 |
+
// Check if we have any data to display
|
| 467 |
+
if (!volumeShape || !maskData) {
|
| 468 |
+
ctx.fillStyle = '#666';
|
| 469 |
+
ctx.font = '12px sans-serif';
|
| 470 |
+
ctx.textAlign = 'center';
|
| 471 |
+
ctx.fillText('No segmentation loaded', width / 2, height / 2);
|
| 472 |
+
return;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
const [D, H, W] = volumeShape; // D=depth, H=height, W=width
|
| 476 |
+
const sliceIndex = appState.get('overlapSliceIndex');
|
| 477 |
+
|
| 478 |
+
// Create image data at native resolution
|
| 479 |
+
const imageData = ctx.createImageData(D, H);
|
| 480 |
+
const pixels = imageData.data;
|
| 481 |
+
|
| 482 |
+
// Color definitions (RGBA)
|
| 483 |
+
const colors = {
|
| 484 |
+
TP: [76, 175, 80, 220], // Green - True Positive
|
| 485 |
+
FP: [244, 67, 54, 220], // Red - False Positive
|
| 486 |
+
FN: [33, 150, 243, 220], // Blue - False Negative
|
| 487 |
+
PRED_ONLY: [255, 152, 0, 220], // Orange - Prediction only (no GT available)
|
| 488 |
+
BG: [26, 26, 46, 255] // Dark background
|
| 489 |
+
};
|
| 490 |
+
|
| 491 |
+
// Iterate through slice pixels
|
| 492 |
+
for (let d = 0; d < D; d++) {
|
| 493 |
+
for (let h = 0; h < H; h++) {
|
| 494 |
+
// Volume index: (d, h, w) where w is the slice index
|
| 495 |
+
const volIdx = d * H * W + h * W + sliceIndex;
|
| 496 |
+
|
| 497 |
+
const pred = maskData[volIdx] > 0;
|
| 498 |
+
const gt = gtMaskData ? gtMaskData[volIdx] > 0 : null;
|
| 499 |
+
|
| 500 |
+
// Canvas pixel index: transposed for correct orientation
|
| 501 |
+
const canvasIdx = (h * D + d) * 4;
|
| 502 |
+
|
| 503 |
+
let color;
|
| 504 |
+
if (gt === null) {
|
| 505 |
+
// No ground truth available
|
| 506 |
+
color = pred ? colors.PRED_ONLY : colors.BG;
|
| 507 |
+
} else {
|
| 508 |
+
// Ground truth available - show TP/FP/FN
|
| 509 |
+
if (pred && gt) {
|
| 510 |
+
color = colors.TP; // True Positive
|
| 511 |
+
} else if (pred && !gt) {
|
| 512 |
+
color = colors.FP; // False Positive
|
| 513 |
+
} else if (!pred && gt) {
|
| 514 |
+
color = colors.FN; // False Negative
|
| 515 |
+
} else {
|
| 516 |
+
color = colors.BG; // True Negative (background)
|
| 517 |
+
}
|
| 518 |
+
}
|
| 519 |
+
|
| 520 |
+
pixels[canvasIdx] = color[0]; // R
|
| 521 |
+
pixels[canvasIdx + 1] = color[1]; // G
|
| 522 |
+
pixels[canvasIdx + 2] = color[2]; // B
|
| 523 |
+
pixels[canvasIdx + 3] = color[3]; // A
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
// Create temp canvas for scaling
|
| 528 |
+
const tempCanvas = document.createElement('canvas');
|
| 529 |
+
tempCanvas.width = D;
|
| 530 |
+
tempCanvas.height = H;
|
| 531 |
+
tempCanvas.getContext('2d').putImageData(imageData, 0, 0);
|
| 532 |
+
|
| 533 |
+
// Scale to fit canvas while maintaining aspect ratio
|
| 534 |
+
const aspect = D / H;
|
| 535 |
+
const canvasAspect = width / height;
|
| 536 |
+
|
| 537 |
+
let displayW, displayH, offsetX, offsetY;
|
| 538 |
+
if (canvasAspect > aspect) {
|
| 539 |
+
displayH = height;
|
| 540 |
+
displayW = displayH * aspect;
|
| 541 |
+
offsetX = (width - displayW) / 2;
|
| 542 |
+
offsetY = 0;
|
| 543 |
+
} else {
|
| 544 |
+
displayW = width;
|
| 545 |
+
displayH = displayW / aspect;
|
| 546 |
+
offsetX = 0;
|
| 547 |
+
offsetY = (height - displayH) / 2;
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
ctx.imageSmoothingEnabled = false; // Nearest neighbor for mask
|
| 551 |
+
ctx.drawImage(tempCanvas, offsetX, offsetY, displayW, displayH);
|
| 552 |
+
|
| 553 |
+
// Update slice indicator
|
| 554 |
+
const sliceIndicator = document.getElementById('overlap-slice-indicator');
|
| 555 |
+
const sliceLabel = document.getElementById('overlap-slice-num');
|
| 556 |
+
if (sliceLabel) {
|
| 557 |
+
sliceLabel.textContent = `${sliceIndex + 1} / ${W}`;
|
| 558 |
+
}
|
| 559 |
+
if (sliceIndicator) {
|
| 560 |
+
sliceIndicator.style.display = 'block';
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
// Hide placeholder
|
| 564 |
+
const placeholder = document.getElementById('overlap-placeholder');
|
| 565 |
+
if (placeholder) placeholder.style.display = 'none';
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
/**
|
| 569 |
+
* Initialize overlap viewer with scroll support
|
| 570 |
+
*/
|
| 571 |
+
function initOverlapViewer() {
|
| 572 |
+
const container = document.getElementById('overlap-canvas-container');
|
| 573 |
+
if (!container) return;
|
| 574 |
+
|
| 575 |
+
// Mouse wheel scrolling
|
| 576 |
+
container.addEventListener('wheel', (e) => {
|
| 577 |
+
e.preventDefault();
|
| 578 |
+
const maskData = appState.get('maskData');
|
| 579 |
+
if (!maskData) return;
|
| 580 |
+
|
| 581 |
+
const delta = e.deltaY > 0 ? 1 : -1;
|
| 582 |
+
const currentSlice = appState.get('overlapSliceIndex');
|
| 583 |
+
appState.setOverlapSlice(currentSlice + delta);
|
| 584 |
+
renderOverlapCanvas();
|
| 585 |
+
}, { passive: false });
|
| 586 |
+
|
| 587 |
+
// Resize observer for responsive canvas
|
| 588 |
+
const resizeObserver = new ResizeObserver(() => {
|
| 589 |
+
if (appState.get('maskData')) {
|
| 590 |
+
renderOverlapCanvas();
|
| 591 |
+
}
|
| 592 |
+
});
|
| 593 |
+
resizeObserver.observe(container);
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
// ============ Tool Management ============
|
| 597 |
|
| 598 |
function enableTools(enabled) {
|
|
|
|
| 601 |
document.getElementById('tool-clear').disabled = !enabled;
|
| 602 |
document.getElementById('run-segment').disabled = !enabled;
|
| 603 |
document.getElementById('run-refine').disabled = !enabled;
|
| 604 |
+
|
| 605 |
+
// Enable ground truth upload when volume is loaded
|
| 606 |
+
const gtBtn = document.getElementById('upload-gt-btn');
|
| 607 |
+
if (gtBtn) gtBtn.disabled = !enabled;
|
| 608 |
}
|
| 609 |
|
| 610 |
function setActiveTool(tool) {
|
|
|
|
| 769 |
document.getElementById('run-segment').addEventListener('click', runSegmentation);
|
| 770 |
document.getElementById('run-refine').addEventListener('click', runRefinement);
|
| 771 |
|
| 772 |
+
// Ground truth upload
|
| 773 |
+
const gtUploadBtn = document.getElementById('upload-gt-btn');
|
| 774 |
+
const gtFileInput = document.getElementById('gt-upload-input');
|
| 775 |
+
|
| 776 |
+
if (gtUploadBtn && gtFileInput) {
|
| 777 |
+
gtUploadBtn.addEventListener('click', () => {
|
| 778 |
+
gtFileInput.click();
|
| 779 |
+
});
|
| 780 |
+
|
| 781 |
+
gtFileInput.addEventListener('change', (e) => {
|
| 782 |
+
const file = e.target.files[0];
|
| 783 |
+
if (file) {
|
| 784 |
+
handleGroundTruthUpload(file);
|
| 785 |
+
}
|
| 786 |
+
// Reset input so same file can be selected again
|
| 787 |
+
e.target.value = '';
|
| 788 |
+
});
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
// Display sliders
|
| 792 |
document.getElementById('window-slider').addEventListener('input', updateWindowValue);
|
| 793 |
document.getElementById('level-slider').addEventListener('input', updateLevelValue);
|
|
|
|
| 815 |
mprViewer.initialize();
|
| 816 |
console.log('MPR viewer initialized');
|
| 817 |
|
| 818 |
+
// Initialize overlap viewer (scroll support)
|
| 819 |
+
initOverlapViewer();
|
| 820 |
+
console.log('Overlap viewer initialized');
|
| 821 |
+
|
| 822 |
// Check backend connection
|
| 823 |
const connected = await checkConnection();
|
| 824 |
console.log('Backend connection:', connected);
|
seg_app/ui_slicer/index.html
CHANGED
|
@@ -137,27 +137,77 @@
|
|
| 137 |
<div class="viewer-header">
|
| 138 |
<span class="view-label info">INFO</span>
|
| 139 |
</div>
|
| 140 |
-
<div class="info-panel" id="info-view">
|
| 141 |
-
<
|
| 142 |
-
|
| 143 |
-
<
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
</div>
|
| 148 |
-
<
|
| 149 |
-
|
| 150 |
-
<
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
</div>
|
| 162 |
</div>
|
| 163 |
</div>
|
|
|
|
| 137 |
<div class="viewer-header">
|
| 138 |
<span class="view-label info">INFO</span>
|
| 139 |
</div>
|
| 140 |
+
<div class="info-panel-split" id="info-view">
|
| 141 |
+
<!-- Left Column: Metrics & Info -->
|
| 142 |
+
<div class="info-left">
|
| 143 |
+
<div class="crosshair-info">
|
| 144 |
+
<h4>📍 Crosshair Position</h4>
|
| 145 |
+
<p>X: <span id="pos-x">-</span></p>
|
| 146 |
+
<p>Y: <span id="pos-y">-</span></p>
|
| 147 |
+
<p>Z: <span id="pos-z">-</span></p>
|
| 148 |
+
<p>Value: <span id="pos-value">-</span></p>
|
| 149 |
+
</div>
|
| 150 |
+
<div class="seg-info">
|
| 151 |
+
<h4>📊 Segmentation</h4>
|
| 152 |
+
<p>Mask loaded: <span id="mask-status">No</span></p>
|
| 153 |
+
<p>Voxels: <span id="mask-voxels">-</span></p>
|
| 154 |
+
<p>Volume: <span id="mask-volume">-</span></p>
|
| 155 |
+
</div>
|
| 156 |
+
<div class="dice-info">
|
| 157 |
+
<h4>🎯 Dice Metrics</h4>
|
| 158 |
+
<p>Ground Truth: <span id="gt-status">Not loaded</span></p>
|
| 159 |
+
<p>Dice Score: <span id="dice-score">-</span></p>
|
| 160 |
+
<p>IoU Score: <span id="iou-score">-</span></p>
|
| 161 |
+
<p>Sensitivity: <span id="sensitivity">-</span></p>
|
| 162 |
+
<p>Precision: <span id="precision-score">-</span></p>
|
| 163 |
+
<input type="file" id="gt-upload-input" accept=".nii,.nii.gz" style="display: none;">
|
| 164 |
+
<button id="upload-gt-btn" class="btn btn-secondary btn-small" disabled>
|
| 165 |
+
📤 Upload Ground Truth
|
| 166 |
+
</button>
|
| 167 |
+
</div>
|
| 168 |
+
<div class="keyboard-shortcuts">
|
| 169 |
+
<h4>⌨️ Shortcuts</h4>
|
| 170 |
+
<p><kbd>Scroll</kbd> Change slice</p>
|
| 171 |
+
<p><kbd>Click</kbd> Add prompt</p>
|
| 172 |
+
<p><kbd>P</kbd> Positive mode</p>
|
| 173 |
+
<p><kbd>N</kbd> Negative mode</p>
|
| 174 |
+
<p><kbd>C</kbd> Clear prompts</p>
|
| 175 |
+
<p><kbd>R</kbd> Run segmentation</p>
|
| 176 |
+
</div>
|
| 177 |
</div>
|
| 178 |
+
<!-- Right Column: Overlap Viewer -->
|
| 179 |
+
<div class="info-right">
|
| 180 |
+
<div class="overlap-viewer">
|
| 181 |
+
<h4>🔍 Mask Comparison</h4>
|
| 182 |
+
<div class="overlap-canvas-container" id="overlap-canvas-container">
|
| 183 |
+
<canvas id="overlap-canvas"></canvas>
|
| 184 |
+
<div class="slice-indicator" id="overlap-slice-indicator">
|
| 185 |
+
<span id="overlap-slice-num">-</span>
|
| 186 |
+
</div>
|
| 187 |
+
<p class="overlap-placeholder" id="overlap-placeholder">
|
| 188 |
+
Run segmentation to view mask<br>
|
| 189 |
+
<small>Scroll to change slice</small>
|
| 190 |
+
</p>
|
| 191 |
+
</div>
|
| 192 |
+
<div class="overlap-legend">
|
| 193 |
+
<div class="legend-item">
|
| 194 |
+
<span class="legend-color tp"></span>
|
| 195 |
+
<span>TP (Both agree)</span>
|
| 196 |
+
</div>
|
| 197 |
+
<div class="legend-item">
|
| 198 |
+
<span class="legend-color fp"></span>
|
| 199 |
+
<span>FP (Pred only)</span>
|
| 200 |
+
</div>
|
| 201 |
+
<div class="legend-item">
|
| 202 |
+
<span class="legend-color fn"></span>
|
| 203 |
+
<span>FN (GT only)</span>
|
| 204 |
+
</div>
|
| 205 |
+
<div class="legend-item">
|
| 206 |
+
<span class="legend-color pred-only"></span>
|
| 207 |
+
<span>Prediction</span>
|
| 208 |
+
</div>
|
| 209 |
+
</div>
|
| 210 |
+
</div>
|
| 211 |
</div>
|
| 212 |
</div>
|
| 213 |
</div>
|
seg_app/ui_slicer/state.js
CHANGED
|
@@ -65,6 +65,11 @@ class AppState {
|
|
| 65 |
maskLoaded: false,
|
| 66 |
maskData: null, // Uint8Array
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
// Display settings
|
| 69 |
windowWidth: 400,
|
| 70 |
windowLevel: 40,
|
|
@@ -215,6 +220,9 @@ class AppState {
|
|
| 215 |
volumeData: null,
|
| 216 |
maskLoaded: false,
|
| 217 |
maskData: null,
|
|
|
|
|
|
|
|
|
|
| 218 |
prompts: []
|
| 219 |
});
|
| 220 |
}
|
|
@@ -330,6 +338,47 @@ class AppState {
|
|
| 330 |
});
|
| 331 |
}
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
// ============ Loading State ============
|
| 334 |
|
| 335 |
/**
|
|
|
|
| 65 |
maskLoaded: false,
|
| 66 |
maskData: null, // Uint8Array
|
| 67 |
|
| 68 |
+
// Ground Truth (for Dice comparison)
|
| 69 |
+
gtLoaded: false,
|
| 70 |
+
gtMaskData: null, // Uint8Array
|
| 71 |
+
overlapSliceIndex: 0, // Slice index for overlap viewer
|
| 72 |
+
|
| 73 |
// Display settings
|
| 74 |
windowWidth: 400,
|
| 75 |
windowLevel: 40,
|
|
|
|
| 220 |
volumeData: null,
|
| 221 |
maskLoaded: false,
|
| 222 |
maskData: null,
|
| 223 |
+
gtLoaded: false,
|
| 224 |
+
gtMaskData: null,
|
| 225 |
+
overlapSliceIndex: 0,
|
| 226 |
prompts: []
|
| 227 |
});
|
| 228 |
}
|
|
|
|
| 338 |
});
|
| 339 |
}
|
| 340 |
|
| 341 |
+
// ============ Ground Truth Methods ============
|
| 342 |
+
|
| 343 |
+
/**
|
| 344 |
+
* Set ground truth mask data
|
| 345 |
+
* @param {Uint8Array} data
|
| 346 |
+
*/
|
| 347 |
+
setGroundTruth(data) {
|
| 348 |
+
// Set overlap slice to center initially
|
| 349 |
+
const shape = this.get('volumeShape');
|
| 350 |
+
const centerSlice = shape ? Math.floor(shape[2] / 2) : 0;
|
| 351 |
+
|
| 352 |
+
this.update({
|
| 353 |
+
gtLoaded: true,
|
| 354 |
+
gtMaskData: data,
|
| 355 |
+
overlapSliceIndex: centerSlice
|
| 356 |
+
});
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
/**
|
| 360 |
+
* Clear ground truth data
|
| 361 |
+
*/
|
| 362 |
+
clearGroundTruth() {
|
| 363 |
+
this.update({
|
| 364 |
+
gtLoaded: false,
|
| 365 |
+
gtMaskData: null,
|
| 366 |
+
overlapSliceIndex: 0
|
| 367 |
+
});
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
/**
|
| 371 |
+
* Set overlap viewer slice index
|
| 372 |
+
* @param {number} index
|
| 373 |
+
*/
|
| 374 |
+
setOverlapSlice(index) {
|
| 375 |
+
const shape = this.get('volumeShape');
|
| 376 |
+
if (!shape) return;
|
| 377 |
+
const maxIndex = shape[2] - 1; // W dimension (axial slices through Z)
|
| 378 |
+
const clampedIndex = Math.max(0, Math.min(index, maxIndex));
|
| 379 |
+
this.set('overlapSliceIndex', clampedIndex);
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
// ============ Loading State ============
|
| 383 |
|
| 384 |
/**
|
seg_app/ui_slicer/styles.css
CHANGED
|
@@ -17,6 +17,11 @@
|
|
| 17 |
--border-color: #2a2a4a;
|
| 18 |
--shadow-color: rgba(0, 0, 0, 0.3);
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
/* View colors */
|
| 21 |
--axial-color: #4caf50;
|
| 22 |
--sagittal-color: #2196f3;
|
|
@@ -484,46 +489,178 @@ body {
|
|
| 484 |
cursor: crosshair;
|
| 485 |
}
|
| 486 |
|
| 487 |
-
/* Info Panel (4th quadrant) */
|
| 488 |
-
.info-panel {
|
| 489 |
flex: 1;
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
overflow-y: auto;
|
|
|
|
|
|
|
| 492 |
}
|
| 493 |
|
| 494 |
-
.info-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
|
|
|
|
|
|
| 498 |
}
|
| 499 |
|
| 500 |
-
.info-
|
|
|
|
| 501 |
font-size: 0.75rem;
|
| 502 |
-
color: var(--text-secondary);
|
| 503 |
margin-bottom: var(--spacing-xs);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
}
|
| 505 |
|
| 506 |
.crosshair-info,
|
| 507 |
.seg-info,
|
|
|
|
| 508 |
.keyboard-shortcuts {
|
| 509 |
-
margin-bottom: var(--spacing-
|
| 510 |
-
padding-bottom: var(--spacing-
|
| 511 |
border-bottom: 1px solid var(--border-color);
|
| 512 |
}
|
| 513 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
.keyboard-shortcuts:last-child {
|
| 515 |
border-bottom: none;
|
| 516 |
}
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
kbd {
|
| 519 |
display: inline-block;
|
| 520 |
-
padding:
|
| 521 |
-
font-size: 0.
|
| 522 |
font-family: monospace;
|
| 523 |
background: var(--bg-tertiary);
|
| 524 |
border: 1px solid var(--border-color);
|
| 525 |
border-radius: 3px;
|
| 526 |
-
min-width:
|
| 527 |
text-align: center;
|
| 528 |
}
|
| 529 |
|
|
|
|
| 17 |
--border-color: #2a2a4a;
|
| 18 |
--shadow-color: rgba(0, 0, 0, 0.3);
|
| 19 |
|
| 20 |
+
/* Semantic colors for metrics */
|
| 21 |
+
--success-color: #4caf50;
|
| 22 |
+
--warning-color: #ff9800;
|
| 23 |
+
--error-color: #f44336;
|
| 24 |
+
|
| 25 |
/* View colors */
|
| 26 |
--axial-color: #4caf50;
|
| 27 |
--sagittal-color: #2196f3;
|
|
|
|
| 489 |
cursor: crosshair;
|
| 490 |
}
|
| 491 |
|
| 492 |
+
/* Info Panel (4th quadrant) - Split Layout */
|
| 493 |
+
.info-panel-split {
|
| 494 |
flex: 1;
|
| 495 |
+
display: flex;
|
| 496 |
+
gap: var(--spacing-sm);
|
| 497 |
+
padding: var(--spacing-sm);
|
| 498 |
+
overflow: hidden;
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
.info-left {
|
| 502 |
+
flex: 1;
|
| 503 |
+
min-width: 0;
|
| 504 |
overflow-y: auto;
|
| 505 |
+
padding-right: var(--spacing-sm);
|
| 506 |
+
border-right: 1px solid var(--border-color);
|
| 507 |
}
|
| 508 |
|
| 509 |
+
.info-right {
|
| 510 |
+
flex: 1;
|
| 511 |
+
min-width: 0;
|
| 512 |
+
display: flex;
|
| 513 |
+
flex-direction: column;
|
| 514 |
+
overflow-y: auto;
|
| 515 |
}
|
| 516 |
|
| 517 |
+
.info-left h4,
|
| 518 |
+
.info-right h4 {
|
| 519 |
font-size: 0.75rem;
|
|
|
|
| 520 |
margin-bottom: var(--spacing-xs);
|
| 521 |
+
color: var(--text-primary);
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
.info-left p {
|
| 525 |
+
font-size: 0.7rem;
|
| 526 |
+
color: var(--text-secondary);
|
| 527 |
+
margin-bottom: 2px;
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
/* Overlap Viewer */
|
| 531 |
+
.overlap-viewer {
|
| 532 |
+
display: flex;
|
| 533 |
+
flex-direction: column;
|
| 534 |
+
height: 100%;
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
.overlap-canvas-container {
|
| 538 |
+
position: relative;
|
| 539 |
+
flex: 1;
|
| 540 |
+
min-height: 160px;
|
| 541 |
+
background: var(--bg-primary);
|
| 542 |
+
border: 1px solid var(--border-color);
|
| 543 |
+
border-radius: var(--radius-sm);
|
| 544 |
+
display: flex;
|
| 545 |
+
align-items: center;
|
| 546 |
+
justify-content: center;
|
| 547 |
+
overflow: hidden;
|
| 548 |
+
cursor: ns-resize;
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
#overlap-canvas {
|
| 552 |
+
width: 100%;
|
| 553 |
+
height: 100%;
|
| 554 |
+
display: block;
|
| 555 |
+
}
|
| 556 |
+
|
| 557 |
+
.overlap-canvas-container .slice-indicator {
|
| 558 |
+
position: absolute;
|
| 559 |
+
bottom: 4px;
|
| 560 |
+
left: 50%;
|
| 561 |
+
transform: translateX(-50%);
|
| 562 |
+
background: rgba(0, 0, 0, 0.7);
|
| 563 |
+
color: var(--text-primary);
|
| 564 |
+
padding: 2px 8px;
|
| 565 |
+
border-radius: var(--radius-sm);
|
| 566 |
+
font-size: 0.65rem;
|
| 567 |
+
font-family: var(--font-mono);
|
| 568 |
+
pointer-events: none;
|
| 569 |
+
z-index: 10;
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
.overlap-placeholder {
|
| 573 |
+
position: absolute;
|
| 574 |
+
font-size: 0.7rem;
|
| 575 |
+
color: var(--text-secondary);
|
| 576 |
+
text-align: center;
|
| 577 |
+
padding: var(--spacing-sm);
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
.overlap-placeholder small {
|
| 581 |
+
font-size: 0.6rem;
|
| 582 |
+
opacity: 0.7;
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
.overlap-placeholder.hidden {
|
| 586 |
+
display: none;
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
/* Overlap Legend */
|
| 590 |
+
.overlap-legend {
|
| 591 |
+
margin-top: var(--spacing-xs);
|
| 592 |
+
display: grid;
|
| 593 |
+
grid-template-columns: 1fr 1fr;
|
| 594 |
+
gap: 2px;
|
| 595 |
+
font-size: 0.6rem;
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
.legend-item {
|
| 599 |
+
display: flex;
|
| 600 |
+
align-items: center;
|
| 601 |
+
gap: 4px;
|
| 602 |
+
color: var(--text-secondary);
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
.legend-color {
|
| 606 |
+
width: 10px;
|
| 607 |
+
height: 10px;
|
| 608 |
+
border-radius: 2px;
|
| 609 |
+
flex-shrink: 0;
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
.legend-color.tp {
|
| 613 |
+
background: #4caf50; /* Green - True Positive */
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
.legend-color.fp {
|
| 617 |
+
background: #f44336; /* Red - False Positive */
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
.legend-color.fn {
|
| 621 |
+
background: #2196f3; /* Blue - False Negative */
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
.legend-color.pred-only {
|
| 625 |
+
background: #ff9800; /* Orange - Prediction only (no GT) */
|
| 626 |
}
|
| 627 |
|
| 628 |
.crosshair-info,
|
| 629 |
.seg-info,
|
| 630 |
+
.dice-info,
|
| 631 |
.keyboard-shortcuts {
|
| 632 |
+
margin-bottom: var(--spacing-sm);
|
| 633 |
+
padding-bottom: var(--spacing-sm);
|
| 634 |
border-bottom: 1px solid var(--border-color);
|
| 635 |
}
|
| 636 |
|
| 637 |
+
.dice-info #gt-status {
|
| 638 |
+
font-weight: 500;
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
.dice-info .btn-small {
|
| 642 |
+
margin-top: var(--spacing-xs);
|
| 643 |
+
padding: 3px 6px;
|
| 644 |
+
font-size: 0.6rem;
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
.keyboard-shortcuts:last-child {
|
| 648 |
border-bottom: none;
|
| 649 |
}
|
| 650 |
|
| 651 |
+
.keyboard-shortcuts p {
|
| 652 |
+
font-size: 0.6rem;
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
kbd {
|
| 656 |
display: inline-block;
|
| 657 |
+
padding: 1px 4px;
|
| 658 |
+
font-size: 0.6rem;
|
| 659 |
font-family: monospace;
|
| 660 |
background: var(--bg-tertiary);
|
| 661 |
border: 1px solid var(--border-color);
|
| 662 |
border-radius: 3px;
|
| 663 |
+
min-width: 40px;
|
| 664 |
text-align: center;
|
| 665 |
}
|
| 666 |
|