---
title: Brain Lesion Segmentation
emoji: 🧠
colorFrom: blue
colorTo: purple
sdk: docker
app_port: 7860
suggested_hardware: t4-small
pinned: false
license: mit
tags:
  - medical-imaging
  - segmentation
  - brain-lesion
  - 3d-visualization
---

🧠 Brain Lesion Segmentation

Interactive 3D brain lesion segmentation tool for MR volumes.

⚠️ Research Prototype

This is a research prototype and is NOT intended for clinical diagnosis or treatment decisions.

  • Not FDA/CE approved
  • For research and educational purposes only
  • Results must be validated by qualified medical professionals

Features

  • Multi-planar Visualization: View axial, sagittal, and coronal slices simultaneously
  • Two Model Options:
    • 3D U-Net (Baseline): Fast automatic segmentation
    • Medical SAM 3D (SA-Med3D-140K): Interactive refinement with point prompts
  • Volume Metrics: Automatic calculation of segmentation volume in mm³
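The volume figure is the number of mask voxels multiplied by the physical size of one voxel. The following is a minimal sketch. It assumes a binary NumPy mask aligned with the input volume and a voxel spacing in millimetres read from the NIfTI header (nibabel exposes it as `img.header.get_zooms()[:3]`). `lesion_volume_mm3` is a hypothetical name, not this app's function.

```python
import numpy as np


def lesion_volume_mm3(mask: np.ndarray, spacing_mm: tuple[float, float, float]) -> float:
    """Return the segmented volume in mm^3 (hypothetical helper, not the app's API).

    mask:       3D array; nonzero voxels belong to the lesion.
    spacing_mm: voxel size along each axis in millimetres.
    """
    if mask.ndim != 3 or len(spacing_mm) != 3:
        raise ValueError("expected a 3D mask and three spacing values")
    voxel_volume = float(np.prod(spacing_mm))  # mm^3 occupied by a single voxel
    return int(np.count_nonzero(mask)) * voxel_volume


# Example: 8 lesion voxels of 1 x 1 x 2 mm each
mask = np.zeros((10, 10, 10), dtype=np.uint8)
mask[2:4, 2:4, 2:4] = 1
print(lesion_volume_mm3(mask, (1.0, 1.0, 2.0)))  # 16.0
```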

Usage

  1. Upload a NIfTI volume (.nii or .nii.gz)
  2. Select a model:
    • Use 3D U-Net for quick automatic segmentation
    • Use Medical SAM 3D with point prompts for interactive refinement
  3. Add prompts (SAM only): Click + for lesion, - for background
  4. Run Segmentation to generate the mask
  5. View overlays and volume metrics

Supported Formats

  • NIfTI-1 (.nii)
  • Compressed NIfTI (.nii.gz)
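Both formats carry the same 348-byte NIfTI-1 header. A `.nii.gz` file is simply the gzip-compressed `.nii`. The following stdlib-only sketch performs a sanity check before upload. It assumes single-file NIfTI-1, which has the magic `n+1` at byte offset 344. The helper names are hypothetical, and the check is a heuristic rather than a full validator such as nibabel's loader.

```python
import gzip
import struct
from pathlib import Path

NIFTI1_HEADER_SIZE = 348               # fixed size of a NIfTI-1 header in bytes
NIFTI1_SINGLE_FILE_MAGIC = b"n+1\x00"  # magic at offset 344 for single-file .nii


def read_header_bytes(path: Path) -> bytes:
    """Read the first 348 bytes, transparently decompressing gzip (.nii.gz)."""
    with open(path, "rb") as fh:
        is_gzip = fh.read(2) == b"\x1f\x8b"  # gzip magic number
    opener = gzip.open if is_gzip else open
    with opener(path, "rb") as fh:
        return fh.read(NIFTI1_HEADER_SIZE)


def looks_like_nifti1(header: bytes) -> bool:
    """Heuristic: sizeof_hdr == 348 (in either byte order) plus the single-file magic."""
    if len(header) < NIFTI1_HEADER_SIZE:
        return False
    sizeof_hdr = (struct.unpack("<i", header[:4])[0], struct.unpack(">i", header[:4])[0])
    return NIFTI1_HEADER_SIZE in sizeof_hdr and header[344:348] == NIFTI1_SINGLE_FILE_MAGIC


def is_supported_upload(path: Path) -> bool:
    """Accept only .nii / .nii.gz files whose header looks like NIfTI-1."""
    return path.name.endswith((".nii", ".nii.gz")) and looks_like_nifti1(read_header_bytes(path))
```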

Models

| Model | Type | Use Case |
|---|---|---|
| 3D U-Net | Automatic | Fast baseline segmentation |
| SA-Med3D-140K | Interactive | Point-prompt refinement |

Technical Details

  • Framework: Gradio 5.x + PyTorch
  • Medical Imaging: nibabel for NIfTI I/O
  • 3D U-Net: MONAI-based architecture
  • SAM-Med3D: SA-Med3D-140K from Hugging Face Hub

Hardware Requirements

  • Recommended: T4 GPU or better
  • Minimum: 8GB GPU memory for 3D inference

Citation

If you use this tool in your research, please cite:

@article{wang2023sammed3d,
  title={SAM-Med3D},
  author={Wang, Haoyu and others},
  journal={arXiv preprint arXiv:2310.15161},
  year={2023}
}

License

MIT License - See LICENSE file for details.


🖥️ Running Locally

This project provides two web application interfaces:

  1. Gradio UI – A simple, server-rendered interface ideal for quick inference and demos
  2. VTK.js Slicer – A professional WebGL-based frontend with smooth MPR navigation (requires FastAPI backend)

Prerequisites

Before running either interface, ensure you have:

  1. Python 3.10+ installed
  2. Conda (recommended) or a Python virtual environment
  3. Git (to clone the repository)
  4. CUDA-capable GPU (recommended for inference, optional for testing)

Step 1: Clone the Repository

git clone https://github.com/your-org/web_app.git
cd web_app

Step 2: Create and Activate the Conda Environment

# Create a new conda environment
conda create -n seg_app python=3.10 -y

# Activate the environment
conda activate seg_app

Step 3: Install Dependencies

# Install PyTorch with CUDA support (adjust CUDA version as needed)
# Visit https://pytorch.org/get-started/locally/ for the correct command
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

# Install project dependencies
pip install -r requirements.txt

Step 4: Verify Installation

# Test that all imports work
python test_import.py
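The contents of `test_import.py` are not shown here. The following is a hypothetical equivalent built on `importlib.util.find_spec`, which locates a top-level module without importing it. The module list is an assumption based on the stack listed under Technical Details.

```python
import importlib.util

# Assumed module list, based on Technical Details; align it with requirements.txt.
REQUIRED_MODULES = ["torch", "nibabel", "gradio", "monai", "fastapi", "uvicorn", "huggingface_hub"]


def missing_modules(names: list[str]) -> list[str]:
    """Return the top-level module names that cannot be located (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def main() -> int:
    """Print a short report and return a process exit code (0 = all found)."""
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
        return 1
    print("All required modules found.")
    return 0
```

To use it as a script, end the file with `raise SystemExit(main())`. `find_spec` only proves that a package is installed. An actual `import torch` can still fail, for example because of a CUDA mismatch, so a thorough check should also import each package.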

Option A: Running the Gradio UI (Recommended for Quick Start)

The Gradio interface is a single-command solution that runs entirely on the server side. Best for quick demos and users unfamiliar with multi-service setups.

Start the Gradio App

# Make sure you're in the web_app directory with conda environment activated
conda activate seg_app

# Run the Gradio app
python app.py

Access the Interface

Once started, you'll see output like:

Starting seg_app in local mode...
Running on local URL:  http://127.0.0.1:7869

Open your browser and navigate to: http://127.0.0.1:7869

Gradio UI Workflow

  1. Upload Volume: Click "Upload NIfTI" and select a .nii or .nii.gz file
  2. Select Model: Choose between:
    • 3D U-Net (Baseline): Fast automatic segmentation
    • Medical SAM 3D: Interactive refinement with point prompts
  3. Add Prompts (SAM only):
    • Use the coordinate inputs to specify (depth, height, width)
    • Click "+ Positive" for lesion points, "+ Negative" for background
  4. Run Segmentation: Click the "Run Segmentation" button
  5. View Results: Overlays appear on the multi-planar views with volume metrics
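Each prompt is a voxel coordinate in (depth, height, width) order plus a lesion/background flag. The sketch below shows how such prompts could be checked against the volume shape and split into coordinates and labels. It assumes the common SAM-style convention of label 1 for foreground and 0 for background. `PointPrompt` and `validate_prompts` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PointPrompt:
    """One click prompt in voxel coordinates (hypothetical structure)."""
    depth: int
    height: int
    width: int
    positive: bool  # True = lesion, False = background


def validate_prompts(
    prompts: list[PointPrompt], shape: tuple[int, int, int]
) -> tuple[list[tuple[int, int, int]], list[int]]:
    """Check that prompts lie inside a (D, H, W) volume; return (coords, labels).

    Labels follow the SAM-style point convention: 1 = positive, 0 = negative.
    """
    coords: list[tuple[int, int, int]] = []
    labels: list[int] = []
    for p in prompts:
        point = (p.depth, p.height, p.width)
        if not all(0 <= c < size for c, size in zip(point, shape)):
            raise ValueError(f"prompt {point} is outside a volume of shape {shape}")
        coords.append(point)
        labels.append(1 if p.positive else 0)
    return coords, labels
```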

Stopping the Server

Press Ctrl+C in the terminal to stop the Gradio server.


Option B: Running the VTK.js Slicer (Professional Frontend)

The VTK.js Slicer provides a professional medical imaging interface with smooth scrolling, WebGL rendering, and click-to-place prompts. This requires running two services:

  1. FastAPI Backend (handles inference)
  2. Frontend Dev Server (serves the VTK.js UI)

Terminal 1: Start the FastAPI Backend

# Activate the conda environment
conda activate seg_app

# Navigate to the web_app directory
cd path\to\web_app

# Start the FastAPI backend with uvicorn
uvicorn seg_app.backend.api:app --reload --host 127.0.0.1 --port 8000

You should see output like:

INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     Started reloader process
INFO:     Started server process
INFO:     Waiting for application startup.
INFO:     Application startup complete.

Note: The --reload flag enables auto-reloading when code changes. Remove it for production.

Terminal 2: Start the Frontend Server

Open a new terminal (keep the backend running):

# Navigate to the ui_slicer directory
cd path\to\web_app\seg_app\ui_slicer

# Option A: Use the included Python server (recommended)
python serve_frontend.py

# Option B: Use npx serve (if you have Node.js installed)
npx serve . -p 5500

# Option C: Use VS Code Live Server extension
# Right-click on index.html → "Open with Live Server"

The Python server will show:

🚀 VTK.js Slicer Frontend Server
========================================
Serving:  ...\seg_app\ui_slicer
URL:      http://localhost:5500
========================================

Press Ctrl+C to stop.
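`serve_frontend.py` is not reproduced in this README. If you need an equivalent, the following stdlib-only sketch serves static files from the `ui_slicer` directory. The function names are hypothetical. `ThreadingHTTPServer` and the `directory` argument of `SimpleHTTPRequestHandler` require Python 3.7+.

```python
import functools
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path


def make_frontend_server(directory: Path, host: str = "localhost", port: int = 5500) -> ThreadingHTTPServer:
    """Build (but do not start) a static file server rooted at `directory`."""
    handler = functools.partial(SimpleHTTPRequestHandler, directory=str(directory))
    return ThreadingHTTPServer((host, port), handler)


def main() -> None:
    ui_dir = Path(__file__).resolve().parent  # assumes the script sits next to index.html
    server = make_frontend_server(ui_dir)
    print(f"Serving {ui_dir} at http://localhost:{server.server_address[1]}")
    try:
        server.serve_forever()  # blocks until Ctrl+C
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()
```

To run it as a script, add `if __name__ == "__main__": main()` at the bottom. Passing `port=0` lets the OS pick a free port, which is useful in tests.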

Access the VTK.js Interface

Open your browser and navigate to: http://localhost:5500

VTK.js Slicer Workflow

  1. Upload Volume: Click "Upload NIfTI" or drag-and-drop a .nii/.nii.gz file
  2. Navigate Slices: Use mouse wheel to scroll through slices in any view
  3. Select Prompt Tool (optional):
    • Press P for positive prompts (mark lesion)
    • Press N for negative prompts (mark background)
    • Click on any view to place prompts
  4. Run Segmentation: Click "Run Segmentation" or press R
  5. Refine: Add more prompts and click "Refine with Prompts" or press Shift+R
  6. Clear Prompts: Press C or click "Clear Prompts"

Keyboard Shortcuts

| Key | Action |
|---|---|
| P | Positive prompt mode |
| N | Negative prompt mode |
| C | Clear all prompts |
| R | Run segmentation |
| Shift+R | Refine with prompts |
| Esc | Cancel prompt mode |

Stopping the Servers

  • Press Ctrl+C in each terminal to stop the respective server

Comparison: Gradio vs VTK.js Slicer

| Feature | Gradio UI | VTK.js Slicer |
|---|---|---|
| Setup Complexity | Single command | Two services |
| Slice Navigation | Button/slider based | Smooth mouse wheel |
| Rendering | Server-rendered PNG | Client-side WebGL (60fps) |
| Prompt Placement | Manual coordinate input | Click on image |
| MPR Views | Static images | Synchronized, real-time |
| Best For | Quick demos, remote access | Power users, research |

Troubleshooting

Common Issues

"Module not found" errors

# Ensure conda environment is activated
conda activate seg_app

# Reinstall dependencies
pip install -r requirements.txt

"CUDA out of memory" or slow inference

  • Reduce input volume size or use CPU inference
  • Close other GPU applications
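One way to shrink the input is to crop it to the bounding box of non-background voxels and paste the predicted mask back into a full-size array afterwards. The following NumPy sketch assumes that background voxels are exactly zero, as in many skull-stripped scans, and that the selected model accepts arbitrary input sizes. The helper names are hypothetical.

```python
import numpy as np


def foreground_bbox(volume: np.ndarray, margin: int = 4, threshold: float = 0.0) -> tuple[slice, ...]:
    """Slices covering all voxels above `threshold`, padded by `margin` voxels per side."""
    coords = np.argwhere(volume > threshold)
    if coords.size == 0:
        return tuple(slice(0, s) for s in volume.shape)  # nothing to crop
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + margin + 1, volume.shape)
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))


def paste_back(cropped_mask: np.ndarray, bbox: tuple[slice, ...], full_shape: tuple[int, ...]) -> np.ndarray:
    """Place a mask predicted on the crop back into a zero-filled full-size array."""
    full = np.zeros(full_shape, dtype=cropped_mask.dtype)
    full[bbox] = cropped_mask
    return full
```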

Backend not connected (VTK.js Slicer)

  • Ensure the FastAPI backend is running on port 8000
  • Check browser console for CORS errors
  • Verify the backend URL in api-client.js matches your setup
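A quick way to rule out a backend problem is to query the `/health` endpoint listed under API Documentation. The following stdlib sketch assumes the endpoint answers with HTTP 200 when healthy. It bypasses any `HTTP_PROXY` setting because the backend runs locally.

```python
import urllib.request

BACKEND_URL = "http://127.0.0.1:8000"  # host/port from the uvicorn command above

# Opener with an empty proxy map, so a local backend is contacted directly.
_DIRECT_OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def backend_is_up(base_url: str = BACKEND_URL, timeout: float = 2.0) -> bool:
    """Return True if GET {base_url}/health responds with HTTP 200."""
    try:
        with _DIRECT_OPENER.open(f"{base_url}/health", timeout=timeout) as resp:
            return resp.status == 200
    except OSError:  # connection refused, timeout, or an HTTP error status
        return False
```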

Blank viewer in VTK.js

  • Check browser console for WebGL errors
  • Ensure volume upload completed successfully
  • Try a different browser (Chrome/Edge recommended)

Gradio app crashes on startup

  • Check for port conflicts: change port in app.py if 7869 is in use
  • Verify PyTorch installation: python -c "import torch; print(torch.cuda.is_available())"
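To check whether port 7869 is already taken before launching, the following stdlib sketch tries to bind the port. The helper name is hypothetical.

```python
import socket


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can bind to (host, port) right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:  # typically EADDRINUSE
            return False
        return True


print(port_is_free(7869))
```

Another process can still grab the port between this check and the launch. Gradio's `launch()` accepts a `server_port` argument and also reads the `GRADIO_SERVER_PORT` environment variable. If `app.py` does not pass `server_port` explicitly, you can therefore pick a free port without editing the code.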

Model Loading Issues

Models are lazy-loaded from Hugging Face Hub on first inference. If you experience issues:

# Clear Hugging Face cache and re-download
rm -r ~/.cache/huggingface/hub/models--blueyo0--SAM-Med3D

# Or manually download the model
python -c "from huggingface_hub import hf_hub_download; hf_hub_download('blueyo0/SAM-Med3D', 'sam_med3d.pth')"
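Hugging Face Hub stores each model repo in a folder named `models--{org}--{name}` under its cache root, and the folder name keeps the repo's capitalization. The following sketch computes that folder from Python. It uses a simplified cache-root lookup: `HF_HUB_CACHE`, then `HF_HOME/hub`, then `~/.cache/huggingface/hub`. It ignores `XDG_CACHE_HOME` and legacy variables.

```python
import os
from pathlib import Path
from typing import Optional


def hub_cache_root() -> Path:
    """Approximate the Hugging Face Hub cache root (simplified lookup order)."""
    if os.environ.get("HF_HUB_CACHE"):
        return Path(os.environ["HF_HUB_CACHE"])
    if os.environ.get("HF_HOME"):
        return Path(os.environ["HF_HOME"]) / "hub"
    return Path.home() / ".cache" / "huggingface" / "hub"


def repo_cache_dir(repo_id: str, root: Optional[Path] = None) -> Path:
    """Cache folder of a model repo: 'org/name' -> '<root>/models--org--name'."""
    root = hub_cache_root() if root is None else root
    return root / ("models--" + repo_id.replace("/", "--"))
```

For example, `shutil.rmtree(repo_cache_dir("blueyo0/SAM-Med3D"), ignore_errors=True)` removes the cached SAM-Med3D files. `huggingface-cli scan-cache` lists what is actually cached.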

API Documentation (Advanced)

When the FastAPI backend is running, you can access the interactive API documentation, which FastAPI serves at http://127.0.0.1:8000/docs by default.

Key endpoints:

  • POST /volume/upload – Upload a NIfTI volume
  • POST /segment – Run segmentation on uploaded volume
  • POST /refine – Refine segmentation with additional prompts
  • GET /mask/{volume_id}/data – Download raw mask data
  • GET /health – Health check endpoint
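The exact request and response schemas are defined by the backend and shown in its interactive API documentation. The following stdlib sketch builds requests for two of the endpoints above. The JSON field names `volume_id` and `model_id` are assumptions for illustration, not the documented schema, and the upload endpoint is omitted.

```python
import json
import urllib.parse
import urllib.request

BASE_URL = "http://127.0.0.1:8000"  # default backend address used in this README


def segment_request(volume_id: str, model_id: str, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build (without sending) a POST /segment request.

    The JSON field names are assumptions; check the API docs for the real schema.
    """
    body = json.dumps({"volume_id": volume_id, "model_id": model_id}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/segment",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def mask_data_url(volume_id: str, base_url: str = BASE_URL) -> str:
    """URL for GET /mask/{volume_id}/data, with the id percent-encoded."""
    return f"{base_url}/mask/{urllib.parse.quote(volume_id, safe='')}/data"
```

Once the upload step has returned a volume id, send the request with `urllib.request.urlopen(segment_request(...))`.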

🧬 Integrating Your Own nnUNet Model

This section explains how to integrate a trained nnUNet v2 model into the web application as a baseline model option.

Prerequisites

  1. Trained nnUNet model: You need a completed nnUNet v2 training run with:

    • Code location (for reference): e.g., /mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet/
    • Checkpoints directory: e.g., /mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/
  2. nnUNet v2 installed: The nnunetv2 package must be installed in your conda environment

Step 1: Install nnUNet v2

First, install nnUNet v2 in your conda environment:

# Activate your environment
conda activate seg_app

# Install nnUNet v2
pip install nnunetv2

Or uncomment the line in requirements.txt:

# nnunetv2>=2.2

And run:

pip install -r requirements.txt

Step 2: Locate Your nnUNet Checkpoint Path

nnUNet training creates a specific folder structure. You need the path to the trainer output folder:

nnUNet_results/
└── Dataset###_Name/                          # e.g., Dataset001_BrainLesion
    └── nnUNetTrainer__nnUNetPlans__3d_fullres/   # ← THIS IS THE PATH YOU NEED
        ├── plans.json                        # Training plans
        ├── dataset.json                      # Dataset configuration
        ├── fold_0/
        │   ├── checkpoint_final.pth          # Model weights
        │   └── checkpoint_best.pth           # Best validation weights
        ├── fold_1/
        │   └── ...
        └── ...

Your checkpoint path should point to the trainer folder, for example:

/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres
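The following sketch checks that a path really is a trainer folder and lists which folds have weights, based on the layout above (a plans file plus `fold_<k>/` subfolders). The helper is hypothetical and not part of `nnunet_wrapper.py`. The fallback to `checkpoint_best.pth` is a choice made by this sketch.

```python
from pathlib import Path

PLANS_FILES = ("plans.json", "nnUNetPlans.json")
CHECKPOINT_PREFERENCE = ("checkpoint_final.pth", "checkpoint_best.pth")


def inspect_trainer_dir(trainer_dir: Path) -> dict[int, Path]:
    """Map fold index -> checkpoint file for an nnU-Net v2 trainer output folder.

    Raises FileNotFoundError when no plans file is present, which usually means the
    dataset folder (one level up) was given instead of the trainer folder.
    """
    if not any((trainer_dir / name).is_file() for name in PLANS_FILES):
        raise FileNotFoundError(f"no plans file in {trainer_dir}; expected the trainer folder")
    folds: dict[int, Path] = {}
    for fold_dir in trainer_dir.glob("fold_*"):
        index = fold_dir.name.removeprefix("fold_")
        if not (fold_dir.is_dir() and index.isdigit()):
            continue  # skips non-numeric folders such as fold_all
        for name in CHECKPOINT_PREFERENCE:  # first existing file wins
            if (fold_dir / name).is_file():
                folds[int(index)] = fold_dir / name
                break
    return dict(sorted(folds.items()))
```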

Step 3: Configure nnUNet in Settings

Edit the file seg_app/config/settings.py and update the NNUNET_CONFIG:

# Global nnUNet configuration instance
NNUNET_CONFIG = nnUNetConfig(
    # Set the path to your trained nnUNet model folder
    checkpoint_path="/mnt/nvme0n1/Dataset/segmentation/CoreLesion/nnUNet_results/Dataset001_BrainLesion/nnUNetTrainer__nnUNetPlans__3d_fullres",
    
    # Optional: which folds to use for inference
    # "all" = use all available folds (ensemble), or [0] for single fold
    use_folds="all",
    
    # Optional: test-time augmentation (mirroring)
    # True = more accurate but slower, False = faster inference
    use_mirroring=True,
    
    # Optional: customize the display name in the UI dropdown
    display_name="nnU-Net (Brain Lesion)",
)
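The real `nnUNetConfig` class is defined in `seg_app/config/settings.py` and is not shown here. The hypothetical stand-in below illustrates how the documented options could be validated, in particular how `use_folds="all"` might expand to the folds that actually exist on disk.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass
class NnUNetSettingsSketch:
    """Hypothetical stand-in mirroring the documented nnUNetConfig options."""
    checkpoint_path: Optional[str] = None
    use_folds: Union[str, Sequence[int]] = "all"
    use_mirroring: bool = True
    display_name: str = "nnU-Net (Brain Lesion)"

    @property
    def enabled(self) -> bool:
        # Only offer the model in the UI when a checkpoint path is configured.
        return self.checkpoint_path is not None

    def resolve_folds(self, available: Sequence[int]) -> tuple[int, ...]:
        """Expand use_folds into concrete fold indices that exist on disk."""
        if self.use_folds == "all":
            if not available:
                raise ValueError("no trained folds found")
            return tuple(sorted(available))
        if isinstance(self.use_folds, str):
            raise ValueError(f"unsupported use_folds value: {self.use_folds!r}")
        missing = sorted(set(self.use_folds) - set(available))
        if missing:
            raise ValueError(f"requested folds not found: {missing}")
        return tuple(self.use_folds)
```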

Step 4: Verify the Integration

After configuration, restart the application and verify nnUNet appears in the model dropdown:

# For Gradio UI
python app.py

# For VTK.js Slicer (FastAPI backend)
uvicorn seg_app.backend.api:app --reload --port 8000

Open the web interface and check that "nnU-Net (Brain Lesion)" appears in the model selection dropdown.

Step 5: Test Inference

  1. Upload a NIfTI volume (.nii or .nii.gz)
  2. Select "nnU-Net (Brain Lesion)" from the model dropdown
  3. Click "Run Segmentation"
  4. View the segmentation overlay and volume metrics

nnUNet Configuration Options

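| Option | Default | Description |
|---|---|---|
| checkpoint_path | None | Required. Path to the nnUNet trainer output folder |
| use_folds | "all" | Which folds to use: "all" for an ensemble, [0] for a single fold |
| use_mirroring | True | Test-time augmentation (mirroring). Improves accuracy but multiplies the forward passes (up to 8 per patch for a 3D model mirrored along all axes) |
| display_name | "nnU-Net (Brain Lesion)" | Name shown in the UI dropdown |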
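Test-time mirroring runs the model on the input and on flipped copies of it, flips each output back, and averages the results. The NumPy sketch below uses a stand-in `predict` callable instead of a real network. Mirroring along all three spatial axes covers the original plus seven flip combinations, for 8 forward passes in total. The sketch illustrates the idea only and is not nnU-Net's code.

```python
from itertools import combinations
from typing import Callable

import numpy as np


def mirrored_prediction(
    predict: Callable[[np.ndarray], np.ndarray],
    x: np.ndarray,
    axes: tuple[int, ...] = (0, 1, 2),
) -> tuple[np.ndarray, int]:
    """Average predict() over the input and all its flips along `axes`.

    Returns the averaged prediction and the number of forward passes used.
    """
    flips = [c for r in range(1, len(axes) + 1) for c in combinations(axes, r)]
    total = predict(x)
    for flip_axes in flips:
        # Flip the input, predict, then flip the output back to the original orientation.
        total = total + np.flip(predict(np.flip(x, flip_axes)), flip_axes)
    n_passes = len(flips) + 1
    return total / n_passes, n_passes
```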

Advanced: Multiple nnUNet Models

To add multiple nnUNet models for different tasks, you can modify the registration logic in seg_app/models/nnunet_wrapper.py. The register_nnunet() function can be extended to register multiple models:

# Example: Register multiple nnUNet models
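def register_nnunet() -> None:
    # Lives in seg_app/models/nnunet_wrapper.py; nnUNetWrapper and ModelConfig are
    # assumed to be defined or imported in that module already.
    from seg_app.inference.model_registry import register_model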
    
    # Model 1: Brain Lesion
    brain_config = ModelConfig(
        model_id="nnunet-brain-lesion",
        local_path="/path/to/nnUNet_results/DatasetBrain/nnUNetTrainer__nnUNetPlans__3d_fullres",
        device="cuda",
    )
    register_model("nnunet-brain-lesion", nnUNetWrapper, brain_config)
    
    # Model 2: Liver (example)
    liver_config = ModelConfig(
        model_id="nnunet-liver",
        local_path="/path/to/nnUNet_results/DatasetLiver/nnUNetTrainer__nnUNetPlans__3d_fullres",
        device="cuda",
    )
    register_model("nnunet-liver", nnUNetWrapper, liver_config)
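The real `register_model` in `seg_app/inference/model_registry.py` is not shown. The minimal sketch below illustrates the pattern the example relies on: a registry mapping each id to a (class, config) pair, which instantiates a model only on first use. With this pattern, registering several nnU-Net models costs nothing until one is selected. All names and behavior here are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class _ModelSpec:
    factory: Callable[[Any], Any]  # e.g. a wrapper class taking a config
    config: Any


_REGISTRY: Dict[str, _ModelSpec] = {}
_INSTANCES: Dict[str, Any] = {}


def register_model(model_id: str, factory: Callable[[Any], Any], config: Any) -> None:
    """Record how to build a model; nothing is loaded yet."""
    if model_id in _REGISTRY:
        raise ValueError(f"model id {model_id!r} is already registered")
    _REGISTRY[model_id] = _ModelSpec(factory, config)


def get_model(model_id: str) -> Any:
    """Build the model on first request (lazy loading) and reuse it afterwards."""
    if model_id not in _INSTANCES:
        spec = _REGISTRY[model_id]  # KeyError for unknown ids
        _INSTANCES[model_id] = spec.factory(spec.config)
    return _INSTANCES[model_id]


def available_models() -> list[str]:
    """Ids to show in the UI dropdown, in registration order."""
    return list(_REGISTRY)
```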

Troubleshooting nnUNet Integration

"nnunetv2 is not installed"

pip install nnunetv2

"plans.json not found"

  • Ensure checkpoint_path points to the trainer folder (not the dataset folder)
  • The path should contain plans.json or nnUNetPlans.json

"checkpoint_final.pth not found"

  • Verify training completed successfully
  • If only checkpoint_best.pth exists, modify the wrapper to load it instead

"Model not appearing in dropdown"

  • Check that checkpoint_path is not None in NNUNET_CONFIG
  • Look for warnings in the terminal when starting the app

Out of memory during inference

  • Set use_mirroring=False to skip the extra mirrored forward passes
  • Use fewer folds: use_folds=[0] instead of "all"
  • Reduce input volume size

Files Modified for nnUNet Integration

| File | Purpose |
|---|---|
| seg_app/models/nnunet_wrapper.py | nnUNet model wrapper implementing the BaseModel interface |
| seg_app/config/settings.py | nnUNet configuration (checkpoint path, inference settings) |
| seg_app/inference/model_registry.py | Registers the nnUNet model during lazy initialization |
| seg_app/inference/orchestrator.py | Adds nnUNet to the available models list |
| requirements.txt | Optional nnunetv2 dependency |