Add documentation site
Add full project documentation and tooling.
- .github/workflows/docs.yml +26 -0
- .gitignore +2 -1
- README.md +122 -1
- docs/api/data.md +140 -0
- docs/api/models.md +104 -0
- docs/api/preprocessing.md +103 -0
- docs/architecture.md +97 -0
- docs/assets/logo.svg +109 -0
- docs/configuration.md +147 -0
- docs/contributing.md +62 -0
- docs/getting-started.md +104 -0
- docs/index.md +35 -0
- docs/inference.md +81 -0
- docs/pipeline.md +99 -0
- mkdocs.yml +67 -0
.github/workflows/docs.yml (added)

```yaml
name: Deploy Documentation

on:
  push:
    branches: [main]
    paths:
      - 'docs/**'
      - 'mkdocs.yml'
  workflow_dispatch:

permissions:
  contents: write

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - run: pip install mkdocs-material

      - run: mkdocs gh-deploy --force
```
.gitignore (changed)

```diff
@@ -6,4 +6,5 @@ temp.ipynb
 __pycache__/
 **/__pycache__/
 *.pyc
-.ruff_cache
+.ruff_cache
+site/
```
README.md (changed)

<p align="center">
  <img src="docs/assets/logo.svg" alt="WSAttention-Prostate Logo" width="240">
</p>

<p align="center">
  <img src="https://img.shields.io/badge/python-3.11-blue?logo=python&logoColor=white" alt="Python 3.11">
  <img src="https://img.shields.io/badge/pytorch-2.5-ee4c2c?logo=pytorch&logoColor=white" alt="PyTorch 2.5">
  <img src="https://img.shields.io/badge/MONAI-1.4-3ddc84" alt="MONAI 1.4">
  <img src="https://img.shields.io/badge/license-Apache%202.0-green" alt="License">
  <a href="https://ai-assisted-healthcare.github.io/WSAttention-Prostate/"><img src="https://img.shields.io/badge/docs-mkdocs-blue" alt="Docs"></a>
</p>

# WSAttention-Prostate

**Weakly-supervised attention-based 3D Multiple Instance Learning for prostate cancer risk prediction on multiparametric MRI.**

WSAttention-Prostate is a two-stage deep learning pipeline that predicts clinically significant prostate cancer (csPCa) risk from T2-weighted, DWI, and ADC MRI sequences. It uses 3D patch-based Multiple Instance Learning with transformer attention to first classify PI-RADS scores, then predict csPCa risk, all without requiring lesion-level annotations.

## Key Features

- **Weakly-supervised attention** – Heatmap-guided patch sampling and a cosine-similarity attention loss replace the need for voxel-level labels
- **3D Multiple Instance Learning** – Extracts volumetric patches from MRI scans and aggregates them via transformer + attention pooling
- **Two-stage pipeline** – Stage 1 trains a 4-class PI-RADS classifier; Stage 2 freezes its backbone and trains a binary csPCa head
- **Multi-seed confidence intervals** – Runs 20 random seeds and reports 95% CIs on AUC, sensitivity, and specificity
- **End-to-end preprocessing** – Registration, segmentation, histogram matching, and heatmap generation in a single configurable pipeline

## Pipeline Overview

```mermaid
flowchart LR
    A[Raw MRI\nT2 + DWI + ADC] --> B[Preprocessing]
    B --> C[Stage 1:\nPI-RADS Classification]
    C --> D[Stage 2:\ncsPCa Prediction]
    D --> E[Risk Score\n+ Top-5 Patches]
```

## Quick Start

```bash
git clone https://github.com/ai-assisted-healthcare/WSAttention-Prostate.git
cd WSAttention-Prostate
pip install -r requirements.txt
pytest tests/
```

## Usage

### Preprocessing

```bash
python preprocess_main.py --config config/config_preprocess.yaml \
    --steps register_and_crop get_segmentation_mask histogram_match get_heatmap
```

### PI-RADS Training

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml
```

### csPCa Training

```bash
python run_cspca.py --mode train --config config/config_cspca_train.yaml
```

### Inference

```bash
python run_pirads.py --mode test --config config/config_pirads_test.yaml --checkpoint <path>
python run_cspca.py --mode test --config config/config_cspca_test.yaml --checkpoint_cspca <path>
python run_inference.py --config config/config_preprocess.yaml
```

See the [full documentation](https://ai-assisted-healthcare.github.io/WSAttention-Prostate/) for detailed configuration options and data format requirements.

## Project Structure

```
WSAttention-Prostate/
├── run_pirads.py          # PI-RADS training/testing entry point
├── run_cspca.py           # csPCa training/testing entry point
├── run_inference.py       # Full inference pipeline
├── preprocess_main.py     # Preprocessing entry point
├── config/                # YAML configuration files
├── src/
│   ├── model/
│   │   ├── MIL.py             # MILModel_3D – core MIL architecture
│   │   └── csPCa_model.py     # csPCa_Model + SimpleNN head
│   ├── data/
│   │   ├── data_loader.py     # MONAI data pipeline
│   │   └── custom_transforms.py
│   ├── train/
│   │   ├── train_pirads.py    # PI-RADS training loop
│   │   └── train_cspca.py     # csPCa training loop
│   ├── preprocessing/         # Registration, segmentation, heatmaps
│   └── utils.py               # Shared utilities and step validation
├── tests/
├── dataset/               # Reference images for histogram matching
└── models/                # Downloaded checkpoints (not in repo)
```

## Architecture

Input MRI patches are processed independently through a 3D ResNet18 backbone, then aggregated via a transformer encoder and attention pooling:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Reshape to [B*N, C, D, H, W]"]
    B --> C[ResNet18-3D Backbone]
    C --> D["Reshape to [B, N, 512]"]
    D --> E[Transformer Encoder\n4 layers, 8 heads]
    E --> F[Attention Pooling\n512 → 2048 → 1]
    F --> G["Weighted Sum [B, 512]"]
    G --> H["FC Head [B, num_classes]"]
```

For csPCa prediction, the backbone is frozen and a 3-layer MLP (`512 → 256 → 128 → 1`) replaces the classification head.

## License

Apache-2.0 – see [LICENSE](LICENSE).
docs/api/data.md (added)

# Data Loading Reference

## get_dataloader

```python
def get_dataloader(args, split: Literal["train", "test"]) -> DataLoader
```

Creates a PyTorch DataLoader with MONAI transforms and persistent caching.

**Parameters:**

| Parameter | Description |
|-----------|-------------|
| `args` | Namespace with `dataset_json`, `data_root`, `tile_size`, `tile_count`, `depth`, `use_heatmap`, `batch_size`, `workers`, `dry_run`, `logdir` |
| `split` | `"train"` or `"test"` |

**Behavior:**

- Loads data lists from a MONAI decathlon-format JSON
- In `dry_run` mode, limits to 8 samples
- Uses `PersistentDataset` with cache stored at `<logdir>/cache/<split>/`
- Training split is shuffled; test split is not
- Uses `list_data_collate` to stack patches into `[B, N, C, D, H, W]`

## Transform Pipeline

Two variants depending on `args.use_heatmap`:

### With Heatmaps (default)

| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC, heatmap (ITKReader, channel-first) |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensity_customd` | Z-score normalize per channel using mask-only statistics |
| 5 | `ElementwiseProductd` | Multiply mask * heatmap → `final_heatmap` |
| 6 | `RandWeightedCropd` | Extract N patches weighted by `final_heatmap` |
| 7 | `EnsureTyped` | Cast labels to float32 |
| 8 | `Transposed` | Reorder image dims for 3D convolution |
| 9 | `DeleteItemsd` | Remove intermediate keys (mask, dwi, adc, heatmap) |
| 10 | `ToTensord` | Convert to PyTorch tensors |

### Without Heatmaps

| Step | Transform | Description |
|------|-----------|-------------|
| 1 | `LoadImaged` | Load T2, mask, DWI, ADC |
| 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
| 3 | `ConcatItemsd` | Stack T2 + DWI + ADC → 3-channel image |
| 4 | `NormalizeIntensityd` | Standard channel-wise normalization (MONAI built-in) |
| 5 | `RandCropByPosNegLabeld` | Extract N patches from positive (mask) regions |
| 6 | `EnsureTyped` | Cast labels to float32 |
| 7 | `Transposed` | Reorder image dims |
| 8 | `DeleteItemsd` | Remove intermediate keys |
| 9 | `ToTensord` | Convert to tensors |

## list_data_collate

```python
def list_data_collate(batch: Sequence) -> dict
```

Custom collation function that stacks per-patient patch lists into batch tensors.

Each sample from the dataset is a list of N patch dictionaries. This function:

1. Stacks `image` across patches: `[N, C, D, H, W]` per sample
2. Stacks `final_heatmap` if present
3. Applies PyTorch's `default_collate` to form the batch dimension

Result: `{"image": [B, N, C, D, H, W], "label": [B], ...}`

## Custom Transforms

### ClipMaskIntensityPercentilesd

```python
ClipMaskIntensityPercentilesd(
    keys: KeysCollection,
    mask_key: str,
    lower: float | None,
    upper: float | None,
    sharpness_factor: float | None = None,
    channel_wise: bool = False,
    dtype: DtypeLike = np.float32,
)
```

Clips image intensity to percentiles computed only from the **masked region**. Supports both hard clipping (default) and soft clipping (via `sharpness_factor`).

### NormalizeIntensity_customd

```python
NormalizeIntensity_customd(
    keys: KeysCollection,
    mask_key: str,
    subtrahend: NdarrayOrTensor | None = None,
    divisor: NdarrayOrTensor | None = None,
    nonzero: bool = False,
    channel_wise: bool = False,
    dtype: DtypeLike = np.float32,
)
```

Z-score normalization where the mean and standard deviation are computed only from **masked voxels**. Supports channel-wise normalization.

### ElementwiseProductd

```python
ElementwiseProductd(
    keys: KeysCollection,
    output_key: str,
)
```

Computes the element-wise product of two arrays from the data dictionary and stores the result in `output_key`. Used to combine the prostate mask with the attention heatmap.

## Dataset JSON Format

The pipeline expects a MONAI decathlon-format JSON file:

```json
{
  "train": [
    {
      "image": "relative/path/to/t2.nrrd",
      "dwi": "relative/path/to/dwi.nrrd",
      "adc": "relative/path/to/adc.nrrd",
      "mask": "relative/path/to/mask.nrrd",
      "heatmap": "relative/path/to/heatmap.nrrd",
      "label": 2
    }
  ],
  "test": [...]
}
```

Paths are relative to `data_root`. The `heatmap` key is only required when `use_heatmap=True`.
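The three-step stacking behavior of `list_data_collate` described in docs/api/data.md can be sketched as follows. This is a minimal illustrative version (the `_sketch` suffix marks it as hypothetical), not the repository's implementation:

```python
# Each dataset sample is a list of N patch dicts; stack the patch tensors
# per sample, then let default_collate add the batch dimension.
import torch
from torch.utils.data import default_collate

def list_data_collate_sketch(batch):
    merged = []
    for sample in batch:  # sample: list of N patch dicts
        # Carry over non-stacked keys (e.g. "label") from the first patch.
        out = {k: v for k, v in sample[0].items()
               if k not in ("image", "final_heatmap")}
        out["image"] = torch.stack([p["image"] for p in sample])  # [N, C, D, H, W]
        if "final_heatmap" in sample[0]:
            out["final_heatmap"] = torch.stack([p["final_heatmap"] for p in sample])
        merged.append(out)
    return default_collate(merged)  # adds batch dim -> image: [B, N, C, D, H, W]
```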
docs/api/models.md (added)

# Models Reference

## MILModel_3D

```python
class MILModel_3D(nn.Module):
    def __init__(
        self,
        num_classes: int,
        mil_mode: str = "att",
        pretrained: bool = True,
        backbone: str | nn.Module | None = None,
        backbone_num_features: int | None = None,
        trans_blocks: int = 4,
        trans_dropout: float = 0.0,
    )
```

**Constructor arguments:**

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `num_classes` | `int` | – | Number of output classes |
| `mil_mode` | `str` | `"att"` | MIL aggregation mode |
| `pretrained` | `bool` | `True` | Use pretrained backbone weights |
| `backbone` | `str \| nn.Module \| None` | `None` | Backbone CNN (`None` = ResNet18-3D) |
| `backbone_num_features` | `int \| None` | `None` | Output features of a custom backbone |
| `trans_blocks` | `int` | `4` | Number of transformer encoder layers |
| `trans_dropout` | `float` | `0.0` | Transformer dropout rate |

**MIL modes:**

| Mode | Description |
|------|-------------|
| `mean` | Average logits across all patches – equivalent to a pure CNN |
| `max` | Keep only the max-probability instance for the loss |
| `att` | Attention-based MIL ([Ilse et al., 2018](https://arxiv.org/abs/1802.04712)) |
| `att_trans` | Transformer + attention MIL ([Shao et al., 2021](https://arxiv.org/abs/2111.01556)) |
| `att_trans_pyramid` | Pyramid transformer using intermediate ResNet layers |

**Key methods:**

- `forward(x, no_head=False)` – Full forward pass. If `no_head=True`, returns patch-level features `[B, N, 512]` before the transformer and attention pooling (used during attention loss computation).
- `calc_head(x)` – Applies the MIL aggregation and classification head to patch features.

**Example:**

```python
import torch
from src.model.MIL import MILModel_3D

model = MILModel_3D(num_classes=4, mil_mode="att_trans")
# Input: [batch, patches, channels, depth, height, width]
x = torch.randn(2, 24, 3, 3, 64, 64)
logits = model(x)  # [2, 4]
```

## csPCa_Model

```python
class csPCa_Model(nn.Module):
    def __init__(self, backbone: nn.Module)
```

Wraps a pre-trained `MILModel_3D` backbone for binary csPCa prediction. The backbone's feature extractor, transformer, and attention mechanism are reused. The original classification head (`myfc`) is replaced by a `SimpleNN`.

**Attributes:**

| Attribute | Type | Description |
|-----------|------|-------------|
| `backbone` | `MILModel_3D` | Frozen PI-RADS backbone |
| `fc_cspca` | `SimpleNN` | Binary classification head |
| `fc_dim` | `int` | Feature dimension (512 for ResNet18) |

**Example:**

```python
import torch
from src.model.MIL import MILModel_3D
from src.model.csPCa_model import csPCa_Model

backbone = MILModel_3D(num_classes=4, mil_mode="att_trans")
model = csPCa_Model(backbone=backbone)

x = torch.randn(2, 24, 3, 3, 64, 64)
prob = model(x)  # [2, 1] sigmoid probabilities
```

## SimpleNN

```python
class SimpleNN(nn.Module):
    def __init__(self, input_dim: int)
```

A lightweight MLP for binary classification:

```
Linear(input_dim, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```

Input: `[B, input_dim]` → Output: `[B, 1]` (probability).
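The `SimpleNN` layer stack listed in docs/api/models.md translates directly into a PyTorch module; a sketch under the assumption that the dropout sits after the second ReLU, exactly as written above:

```python
# Sketch of the SimpleNN head: Linear(in, 256) -> ReLU ->
# Linear(256, 128) -> ReLU -> Dropout(0.3) -> Linear(128, 1) -> Sigmoid.
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)  # [B, 1] probabilities in [0, 1]
```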
docs/api/preprocessing.md (added)

# Preprocessing Reference

## Overview

Preprocessing is orchestrated by `preprocess_main.py`, which runs steps in sequence. Each step receives and returns the `args` namespace, updating directory paths as it goes.

## Step Dependencies

| Step | Requires |
|------|----------|
| `register_and_crop` | – |
| `get_segmentation_mask` | `register_and_crop` |
| `histogram_match` | `register_and_crop`, `get_segmentation_mask` |
| `get_heatmap` | `register_and_crop`, `get_segmentation_mask`, `histogram_match` |

Dependencies are validated at runtime: the pipeline exits with an error if steps are out of order.

## register_files

```python
def register_files(args) -> args
```

Registers and crops T2, DWI, and ADC images to a standardized spacing and size.

**Process:**

1. Reads images from `args.t2_dir`, `args.dwi_dir`, `args.adc_dir`
2. Resamples to spacing `(0.4, 0.4, 3.0)` mm using `picai_prep.Sample`
3. Center-crops with `args.margin` (default 0.2) in the x/y dimensions
4. Saves to `<output_dir>/t2_registered/`, `DWI_registered/`, `ADC_registered/`

**Updates `args`:** `t2_dir`, `dwi_dir`, `adc_dir` → registered directories.

## get_segmask

```python
def get_segmask(args) -> args
```

Generates prostate segmentation masks from T2W images using a pre-trained model.

**Process:**

1. Loads the model config from `<project_dir>/config/inference.json`
2. Loads the checkpoint from `<project_dir>/models/prostate_segmentation_model.pt`
3. Applies MONAI transforms: orientation (RAS), spacing (0.5 mm isotropic), intensity normalization
4. Runs inference and inverts the transforms back to the original space
5. Post-processes: retains only the top 10 slices by non-zero voxel count
6. Saves NRRD masks to `<output_dir>/prostate_mask/`

**Updates `args`:** adds `seg_dir`.

## histmatch

```python
def histmatch(args) -> args
```

Matches the intensity histogram of each modality to a reference image.

**Process:**

1. Reads reference images from `<project_dir>/dataset/` (`t2_reference.nrrd`, `dwi_reference.nrrd`, `adc_reference.nrrd`, `prostate_segmentation_reference.nrrd`)
2. For each patient, matches histograms within the prostate mask using `skimage.exposure.match_histograms`
3. Saves to `<output_dir>/t2_histmatched/`, `DWI_histmatched/`, `ADC_histmatched/`

**Updates `args`:** `t2_dir`, `dwi_dir`, `adc_dir` → histogram-matched directories.

### get_histmatched

```python
def get_histmatched(
    data: np.ndarray,
    ref_data: np.ndarray,
    mask: np.ndarray,
    ref_mask: np.ndarray,
) -> np.ndarray
```

Low-level function that performs histogram matching on masked regions only. Unmasked pixels remain unchanged.

## get_heatmap

```python
def get_heatmap(args) -> args
```

Generates combined DWI/ADC attention heatmaps.

**Process:**

1. For each file, reads the DWI, ADC, and prostate mask
2. Computes the DWI heatmap: `(dwi - min) / (max - min)` within the mask
3. Computes the ADC heatmap: `(max - adc) / (max - min)` within the mask (inverted: low ADC = high attention)
4. Combines them via element-wise multiplication
5. Re-normalizes to [0, 1]
6. Saves to `<output_dir>/heatmaps/`

**Updates `args`:** adds `heatmapdir`.

!!! info "Edge cases"
    If all values within the mask are identical for a modality (DWI or ADC), that modality's heatmap is skipped. If both are constant, the heatmap defaults to all ones.
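The `get_heatmap` process and its edge cases can be sketched in NumPy. This is an illustrative sketch assuming NumPy arrays and a binary prostate mask; the function name `combined_heatmap` is hypothetical:

```python
# Min-max normalize DWI, invert-normalize ADC, multiply, restrict to the
# mask, and re-normalize to [0, 1]; constant modalities are skipped.
import numpy as np

def combined_heatmap(dwi, adc, mask):
    heat = np.ones_like(dwi, dtype=np.float32)  # defaults to ones if both constant
    m = mask > 0
    if dwi[m].max() > dwi[m].min():  # skip if DWI is constant within the mask
        heat *= (dwi - dwi[m].min()) / (dwi[m].max() - dwi[m].min())
    if adc[m].max() > adc[m].min():  # inverted: low ADC = high attention
        heat *= (adc[m].max() - adc) / (adc[m].max() - adc[m].min())
    heat *= m                        # zero out everything outside the prostate
    if heat.max() > 0:
        heat /= heat.max()           # re-normalize to [0, 1]
    return heat
```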
docs/architecture.md (added)

# Architecture

## Tensor Shape Convention

Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:

| Dim | Meaning | Typical Value |
|-----|---------|---------------|
| B | Batch size | 4–8 |
| N | Number of patches (instances) | 24 |
| C | Channels (T2 + DWI + ADC) | 3 |
| D | Depth (slices per patch) | 3 |
| H | Patch height | 64 |
| W | Patch width | 64 |

## MILModel_3D

The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Reshape to [B*N, C, D, H, W]"]
    B --> C[ResNet18-3D Backbone]
    C --> D["Reshape to [B, N, 512]"]
    D --> E[Transformer Encoder\n4 layers, 8 heads]
    E --> F[Attention Pooling\n512 → 2048 → 1]
    F --> G["Weighted Sum [B, 512]"]
    G --> H["FC Head [B, num_classes]"]
```

### Forward Pass

1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.

2. **Transformer**: Features are reshaped to `[B, N, 512]`, permuted to `[N, B, 512]` for the transformer encoder (4 layers, 8 attention heads), then permuted back.

3. **Attention**: A two-layer attention network (`512 → 2048 → 1` with Tanh) computes a scalar weight per patch, normalized via softmax.

4. **Classification**: The attention-weighted sum of patch features produces a single `[B, 512]` vector per scan, which is projected to class logits by a linear layer.

### MIL Modes

| Mode | Aggregation Strategy |
|------|---------------------|
| `mean` | Average logits across patches |
| `max` | Max logits across patches |
| `att` | Attention-weighted feature pooling |
| `att_trans` | Transformer encoder + attention pooling (primary mode) |
| `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |

The default and primary mode is `att_trans`.

## csPCa_Model

Wraps a frozen `MILModel_3D` backbone and replaces the classification head:

```mermaid
flowchart TD
    A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone\n(ResNet18 + Transformer)"]
    B --> C["Pooled Features [B, 512]"]
    C --> D["SimpleNN Head\n512 → 256 → 128 → 1"]
    D --> E["Sigmoid → csPCa Probability"]
```

### SimpleNN

```
Linear(512, 256) → ReLU
Linear(256, 128) → ReLU → Dropout(0.3)
Linear(128, 1) → Sigmoid
```

During csPCa training, the backbone's `net` (ResNet18), `transformer`, and `myfc` parameters are frozen. The `attention` module and the `SimpleNN` head remain trainable.

## Attention Loss

During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:

```
total_loss = class_loss + lambda_att * attention_loss
```

- **Classification loss**: Standard cross-entropy on PI-RADS labels
- **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`
- Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
- PI-RADS 2 samples get uniform attention (no expected lesion)
- `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
- Attention predictions are computed on detached transformer outputs to avoid gradient interference with classification

## Patch Extraction

Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):

- **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map, so regions with high DWI and low ADC are sampled more frequently
- **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask

Each scan yields `N` patches (default 24) of size `tile_size × tile_size × depth` (default 64×64×3).
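The dual-loss objective and heatmap-derived attention targets described in docs/architecture.md can be written out explicitly. This is an illustrative sketch; the function names and the exact warm-up formula are assumptions based on the description above (linear from 0 to 2.0 over 25 epochs):

```python
# Sketch of the dual-loss objective: cross-entropy + warmed-up attention loss.
import torch
import torch.nn.functional as F

def attention_targets(patch_heatmaps: torch.Tensor) -> torch.Tensor:
    # patch_heatmaps: [B, N, D, H, W] heatmap values within each patch.
    # Sum spatially per patch, square for sharpness, normalize to sum to 1.
    t = patch_heatmaps.flatten(2).sum(-1) ** 2          # [B, N]
    return t / t.sum(dim=1, keepdim=True).clamp_min(1e-8)

def total_loss(logits, labels, att_pred, att_target,
               epoch: int, warmup: int = 25, max_lambda: float = 2.0):
    lam = max_lambda * min(epoch / warmup, 1.0)          # linear warm-up
    class_loss = F.cross_entropy(logits, labels)
    att_loss = 1 - F.cosine_similarity(att_pred, att_target, dim=1).mean()
    return class_loss + lam * att_loss
```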
docs/assets/logo.svg (added; content not shown)
docs/configuration.md
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configuration

## Config System

Configuration follows a three-level hierarchy:

1. **CLI defaults** — Argparse defaults in `run_pirads.py`, `run_cspca.py`, etc.
2. **YAML overrides** — Values from `--config <file>.yaml` override CLI defaults
3. **SLURM job name** — If `SLURM_JOB_NAME` is set, it overrides `run_name`

```bash
# CLI defaults are overridden by YAML config
python run_pirads.py --mode train --config config/config_pirads_train.yaml
```

!!! note

    YAML values **always override** CLI defaults for any key present in the YAML file (`args.__dict__.update(config)`). To override a YAML value, edit the YAML file or omit the key from YAML so the CLI default is used.
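The merge order can be sketched in a few lines (`resolve_config` and the inline dicts are hypothetical stand-ins for the argparse/YAML plumbing in the run scripts):

```python
import argparse

def resolve_config(cli_args, yaml_config, slurm_job_name=None):
    """Apply the three-level override order: CLI defaults < YAML < SLURM job name."""
    args = argparse.Namespace(**cli_args)
    # Level 2: YAML values override CLI defaults for every key present
    args.__dict__.update(yaml_config)
    # Level 3: SLURM job name overrides run_name when set
    if slurm_job_name:
        args.run_name = slurm_job_name
    return args

cli_defaults = {"optim_lr": 3e-5, "batch_size": 4, "run_name": "train_pirads"}
yaml_overrides = {"optim_lr": 2e-4, "batch_size": 8}

args = resolve_config(cli_defaults, yaml_overrides, slurm_job_name="pirads_fold0")
print(args.optim_lr, args.batch_size, args.run_name)  # 0.0002 8 pirads_fold0
```

In the real scripts the second argument would come from `yaml.safe_load(open(args.config))`.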

## PI-RADS Training Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `mode` | — | `train` or `test` (required) |
| `config` | — | Path to YAML config file |
| `data_root` | — | Root folder of images |
| `dataset_json` | — | Path to dataset JSON file |
| `num_classes` | `4` | Number of output classes (PI-RADS 2–5) |
| `mil_mode` | `att_trans` | MIL algorithm (`mean`, `max`, `att`, `att_trans`, `att_trans_pyramid`) |
| `tile_count` | `24` | Number of patches per scan |
| `tile_size` | `64` | Patch spatial size in pixels |
| `depth` | `3` | Number of slices per patch |
| `use_heatmap` | `True` | Enable heatmap-guided patch sampling |
| `workers` | `2` | DataLoader workers |
| `checkpoint` | `None` | Path to resume from checkpoint |
| `epochs` | `50` | Max training epochs |
| `early_stop` | `40` | Epochs without improvement before stopping |
| `batch_size` | `4` | Scans per batch |
| `optim_lr` | `3e-5` | Base learning rate |
| `weight_decay` | `0` | Optimizer weight decay |
| `amp` | `False` | Enable automatic mixed precision |
| `val_every` | `1` | Validation frequency (epochs) |
| `wandb` | `False` | Enable Weights & Biases logging |
| `project_name` | `Classification_prostate` | W&B project name |
| `run_name` | `train_pirads` | Run name for logging |
| `dry_run` | `False` | Quick test mode |

## csPCa Training Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `mode` | — | `train` or `test` (required) |
| `config` | — | Path to YAML config file |
| `data_root` | — | Root folder of images |
| `dataset_json` | — | Path to dataset JSON file |
| `num_classes` | `4` | PI-RADS classes (for backbone initialization) |
| `mil_mode` | `att_trans` | MIL algorithm for backbone |
| `tile_count` | `24` | Number of patches per scan |
| `tile_size` | `64` | Patch spatial size |
| `depth` | `3` | Slices per patch |
| `use_heatmap` | `True` | Enable heatmap-guided patch sampling |
| `workers` | `2` | DataLoader workers |
| `checkpoint_pirads` | — | Path to pre-trained PI-RADS model (required for train) |
| `checkpoint_cspca` | — | Path to csPCa checkpoint (required for test) |
| `epochs` | `30` | Max training epochs |
| `batch_size` | `32` | Scans per batch |
| `optim_lr` | `2e-4` | Learning rate |
| `num_seeds` | `20` | Number of random seeds for CI |
| `val_every` | `1` | Validation frequency |
| `dry_run` | `False` | Quick test mode |

## Preprocessing Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `config` | — | Path to YAML config file |
| `steps` | — | Steps to execute (required, one or more) |
| `t2_dir` | — | Directory of T2W images |
| `dwi_dir` | — | Directory of DWI images |
| `adc_dir` | — | Directory of ADC images |
| `seg_dir` | — | Directory of segmentation masks |
| `output_dir` | — | Output directory |
| `margin` | `0.2` | Center-crop margin fraction |
| `project_dir` | — | Project root (for reference images and models) |

## Example YAML

=== "PI-RADS Training"

    ```yaml
    data_root: /path/to/registered/t2_hist_matched
    dataset_json: /path/to/PI-RADS_data.json
    num_classes: 4
    mil_mode: att_trans
    tile_count: 24
    tile_size: 64
    depth: 3
    use_heatmap: true
    workers: 4
    epochs: 100
    batch_size: 8
    optim_lr: 2e-4
    weight_decay: 1e-5
    amp: true
    wandb: true
    ```

=== "csPCa Training"

    ```yaml
    data_root: /path/to/registered/t2_hist_matched
    dataset_json: /path/to/csPCa_data.json
    num_classes: 4
    mil_mode: att_trans
    tile_count: 24
    tile_size: 64
    depth: 3
    use_heatmap: true
    workers: 6
    checkpoint_pirads: /path/to/models/pirads.pt
    epochs: 80
    batch_size: 8
    optim_lr: 2e-4
    ```

=== "Preprocessing"

    ```yaml
    t2_dir: /path/to/raw/t2
    dwi_dir: /path/to/raw/dwi
    adc_dir: /path/to/raw/adc
    output_dir: /path/to/processed
    project_dir: /path/to/WSAttention-Prostate
    ```

## Dry-Run Mode

The `--dry_run` flag configures a minimal run for quick testing:

- Epochs: 2
- Batch size: 2
- Workers: 0
- Seeds: 2
- W&B: disabled

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml --dry_run
```
docs/contributing.md
ADDED
@@ -0,0 +1,62 @@
# Contributing

## Running Tests

```bash
# Full test suite
pytest tests/

# Single test
pytest tests/test_run.py::test_run_pirads_training
```

Tests run in `--dry_run` mode (2 epochs, batch_size=2, no W&B logging).

## Linting

This project uses [Ruff](https://docs.astral.sh/ruff/) for linting and formatting:

```bash
# Check for lint errors
ruff check .

# Auto-format code
ruff format .
```

**Ruff configuration** (from `pyproject.toml`):

| Setting | Value |
|---------|-------|
| Line length | 100 |
| Quote style | Double quotes |
| Rules | E (errors), W (warnings) |
| Ignored | E501 (line too long) |
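For reference, the table corresponds to a `pyproject.toml` section along these lines (a sketch using Ruff's documented keys; the project's actual file may differ):

```toml
[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = ["E", "W"]
ignore = ["E501"]

[tool.ruff.format]
quote-style = "double"
```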

## SLURM Job Scripts

Job scripts are in `job_scripts/` and are configured for GPU partitions:

```bash
sbatch job_scripts/train_pirads.sh
sbatch job_scripts/train_cspca.sh
```

Key SLURM settings used:

| Setting | Value |
|---------|-------|
| Partition | `gpu` |
| Memory | 128 GB |
| GPUs | 1 |
| Time limit | 48 hours |

!!! tip

    The SLURM job name (`--job-name`) automatically becomes the `run_name`, which determines the log directory at `logs/<run_name>/`.

## Project Conventions

- **Configs** are stored in `config/` as YAML files
- **Logs** are written to `logs/<run_name>/` including TensorBoard events and training logs
- **Models** are saved to `logs/<run_name>/` during training; best models are saved to `models/` for deployment
- **Cache** is stored at `logs/<run_name>/cache/` and cleaned up automatically after training
docs/getting-started.md
ADDED
@@ -0,0 +1,104 @@
# Getting Started

## Prerequisites

- Python 3.11+
- NVIDIA GPU recommended (CUDA-compatible)
- ~128 GB RAM for training (configurable via batch size)

## Installation

```bash
git clone https://github.com/ai-assisted-healthcare/WSAttention-Prostate.git
cd WSAttention-Prostate
pip install -r requirements.txt
```

### External Git Dependencies

Two packages are installed directly from GitHub repositories:

| Package | Source | Purpose |
|---------|--------|---------|
| `AIAH_utility` | `ai-assisted-healthcare/AIAH_utility` | Healthcare imaging utilities |
| `grad-cam` | `jacobgil/pytorch-grad-cam` | Gradient-weighted class activation maps |

These are included in `requirements.txt` and install automatically.

## Verify Installation

Run the test suite in dry-run mode:

```bash
pytest tests/
```

Tests use `--dry_run` mode internally (2 epochs, batch_size=2, no W&B).

## Data Format

Input MRI scans should be in **NRRD** or **NIfTI** format with three modalities per patient:

- T2-weighted (T2W)
- Diffusion-weighted imaging (DWI)
- Apparent diffusion coefficient (ADC)

### Dataset JSON Structure

The data pipeline uses MONAI's decathlon-format JSON:

```json
{
  "train": [
    {
      "image": "path/to/t2.nrrd",
      "dwi": "path/to/dwi.nrrd",
      "adc": "path/to/adc.nrrd",
      "mask": "path/to/prostate_mask.nrrd",
      "heatmap": "path/to/heatmap.nrrd",
      "label": 0
    }
  ],
  "test": [
    ...
  ]
}
```

The `image` key points to the T2W image, which serves as the reference modality. Labels for PI-RADS are 0-indexed: label `0` = PI-RADS 2, label `3` = PI-RADS 5. For csPCa, labels are binary (0 or 1).
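Reading such a file and mapping labels back to PI-RADS grades is a one-liner per entry (a stdlib-only sketch; the inline JSON mirrors the example above):

```python
import json

# Hypothetical snippet: parse a decathlon-style dataset JSON and recover
# the human-readable PI-RADS grade from the 0-indexed label.
dataset = json.loads("""
{
  "train": [
    {"image": "path/to/t2.nrrd", "dwi": "path/to/dwi.nrrd",
     "adc": "path/to/adc.nrrd", "mask": "path/to/prostate_mask.nrrd",
     "heatmap": "path/to/heatmap.nrrd", "label": 3}
  ],
  "test": []
}
""")

for entry in dataset["train"]:
    pirads = entry["label"] + 2  # label 0 -> PI-RADS 2, label 3 -> PI-RADS 5
    print(entry["image"], "PI-RADS", pirads)
```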

## Project Structure

```
WSAttention-Prostate/
├── run_pirads.py          # PI-RADS training/testing entry point
├── run_cspca.py           # csPCa training/testing entry point
├── run_inference.py       # Full inference pipeline
├── preprocess_main.py     # Preprocessing entry point
├── config/                # YAML configuration files
│   ├── config_pirads_train.yaml
│   ├── config_pirads_test.yaml
│   ├── config_cspca_train.yaml
│   ├── config_cspca_test.yaml
│   └── config_preprocess.yaml
├── src/
│   ├── model/
│   │   ├── MIL.py                 # MILModel_3D — core MIL architecture
│   │   └── csPCa_model.py         # csPCa_Model + SimpleNN head
│   ├── data/
│   │   ├── data_loader.py         # MONAI data pipeline
│   │   └── custom_transforms.py
│   ├── train/
│   │   ├── train_pirads.py        # PI-RADS training loop
│   │   └── train_cspca.py         # csPCa training loop
│   ├── preprocessing/
│   │   ├── register_and_crop.py
│   │   ├── prostate_mask.py
│   │   ├── histogram_match.py
│   │   └── generate_heatmap.py
│   └── utils.py
├── job_scripts/           # SLURM job templates
├── tests/
├── dataset/               # Reference images for histogram matching
└── models/                # Pre-trained model checkpoints
```
docs/index.md
ADDED
@@ -0,0 +1,35 @@
<div style="text-align: center; margin-bottom: 2em;">
  <img src="assets/logo.svg" alt="WSAttention-Prostate Logo" width="240">
</div>

# WSAttention-Prostate

**Weakly-supervised attention-based 3D Multiple Instance Learning for prostate cancer risk prediction on multiparametric MRI.**

WSAttention-Prostate is a two-stage deep learning pipeline that predicts clinically significant prostate cancer (csPCa) risk from T2-weighted, DWI, and ADC MRI sequences. It uses 3D patch-based Multiple Instance Learning with transformer attention to first classify PI-RADS scores, then predict csPCa risk — all without requiring lesion-level annotations.

## Key Features

- **Weakly-supervised attention** — Heatmap-guided patch sampling and cosine-similarity attention loss replace the need for voxel-level labels
- **3D Multiple Instance Learning** — Extracts volumetric patches from MRI scans and aggregates them via transformer + attention pooling
- **Two-stage pipeline** — Stage 1 trains a 4-class PI-RADS classifier; Stage 2 freezes its backbone and trains a binary csPCa head
- **Multi-seed confidence intervals** — Runs 20 random seeds and reports 95% CI on AUC, sensitivity, and specificity
- **End-to-end preprocessing** — Registration, segmentation, histogram matching, and heatmap generation in a single configurable pipeline

## Pipeline Overview

```mermaid
flowchart LR
    A[Raw MRI\nT2 + DWI + ADC] --> B[Preprocessing]
    B --> C[Stage 1:\nPI-RADS Classification]
    C --> D[Stage 2:\ncsPCa Prediction]
    D --> E[Risk Score\n+ Top-5 Patches]
```

## Quick Links

- [Getting Started](getting-started.md) — Installation and first run
- [Pipeline](pipeline.md) — Full walkthrough of preprocessing, training, and evaluation
- [Architecture](architecture.md) — Model design and tensor shapes
- [Configuration](configuration.md) — YAML config reference
- [Inference](inference.md) — Running predictions on new data
docs/inference.md
ADDED
@@ -0,0 +1,81 @@
# Inference

## Full Pipeline

`run_inference.py` runs the complete pipeline: preprocessing followed by PI-RADS classification and csPCa risk prediction.

```bash
python run_inference.py --config config/config_preprocess.yaml
```

This script:

1. Runs all four preprocessing steps (register, segment, histogram match, heatmap)
2. Loads the PI-RADS model from `models/pirads.pt`
3. Loads the csPCa model from `models/cspca_model.pth`
4. For each scan: predicts PI-RADS score, csPCa risk probability, and identifies the top-5 most-attended patches

### Required Model Files

Place these in the `models/` directory:

| File | Description |
|------|-------------|
| `pirads.pt` | Trained PI-RADS MIL model checkpoint |
| `cspca_model.pth` | Trained csPCa model checkpoint |
| `prostate_segmentation_model.pt` | Pre-trained prostate segmentation model |

### Output Format

Results are saved to `<output_dir>/results.json`:

```json
{
  "patient_001.nrrd": {
    "Predicted PIRAD Score": 4.0,
    "csPCa risk": 0.8234,
    "Top left coordinate of top 5 patches(x,y,z)": [
      [32, 45, 7],
      [28, 50, 7],
      [35, 42, 8],
      [30, 48, 6],
      [33, 44, 8]
    ]
  }
}
```

### Label Mapping

PI-RADS predictions are 0-indexed internally and shifted by +2 for display:

| Internal Label | PI-RADS Score |
|---------------|---------------|
| 0 | PI-RADS 2 |
| 1 | PI-RADS 3 |
| 2 | PI-RADS 4 |
| 3 | PI-RADS 5 |

csPCa risk is a continuous probability in [0, 1].
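Downstream consumption of `results.json` can look like this (a hypothetical reader script; the 0.5 decision threshold is illustrative only, not part of the pipeline):

```python
import json

# Inline copy of the results format shown above (normally: open the file).
results = json.loads("""
{
  "patient_001.nrrd": {
    "Predicted PIRAD Score": 4.0,
    "csPCa risk": 0.8234,
    "Top left coordinate of top 5 patches(x,y,z)": [[32, 45, 7]]
  }
}
""")

for scan, pred in results.items():
    risk = pred["csPCa risk"]
    flagged = risk >= 0.5  # illustrative threshold only
    print(f"{scan}: PI-RADS {pred['Predicted PIRAD Score']:.0f}, "
          f"csPCa risk {risk:.2f}, flagged={flagged}")
```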

## Testing Individual Models

### PI-RADS Testing

```bash
python run_pirads.py --mode test \
    --config config/config_pirads_test.yaml \
    --checkpoint models/pirads.pt
```

Reports Quadratic Weighted Kappa (QWK) across multiple seeds.

### csPCa Testing

```bash
python run_cspca.py --mode test \
    --config config/config_cspca_test.yaml \
    --checkpoint_cspca models/cspca_model.pth
```

Reports AUC, sensitivity, and specificity with 95% confidence intervals across 20 seeds (default).
docs/pipeline.md
ADDED
@@ -0,0 +1,99 @@
# Pipeline

The full pipeline has three phases: preprocessing, PI-RADS training (Stage 1), and csPCa training (Stage 2).

```mermaid
flowchart TD
    subgraph Preprocessing
        R[register_and_crop] --> S[get_segmentation_mask]
        S --> H[histogram_match]
        H --> G[get_heatmap]
    end

    subgraph Stage 1
        P[PI-RADS Training\nCrossEntropy + Attention Loss]
    end

    subgraph Stage 2
        C[csPCa Training\nFrozen Backbone + BCE Loss]
    end

    G --> P
    P -->|frozen backbone| C
```

## Preprocessing

Run all four steps in sequence:

```bash
python preprocess_main.py \
    --config config/config_preprocess.yaml \
    --steps register_and_crop get_segmentation_mask histogram_match get_heatmap
```

### Step 1: Register and Crop

Resamples T2, DWI, and ADC to a common spacing of `(0.4, 0.4, 3.0)` mm using `picai_prep`, then center-crops with a configurable margin (default 20%).

### Step 2: Prostate Segmentation

Runs a pre-trained segmentation model on T2W images to generate binary prostate masks. Post-processing retains only the top 10 slices by non-zero voxel count.

### Step 3: Histogram Matching

Matches the intensity histogram of each modality to a reference image within masked (prostate) regions using `skimage.exposure.match_histograms`.

### Step 4: Heatmap Generation

Creates attention heatmaps from DWI and ADC:

- **DWI heatmap**: `(dwi - min) / (max - min)` — higher DWI signal = higher attention
- **ADC heatmap**: `(max - adc) / (max - min)` — lower ADC = higher attention (inverted)
- **Combined**: element-wise product, re-normalized to [0, 1]
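The three formulas above can be sketched on flat intensity lists (a stdlib illustration; the real step operates on masked 3D volumes):

```python
def minmax(xs):
    """Min-max normalize a list of numbers into [0, 1]."""
    lo, hi = min(xs), max(xs)
    return [(x - lo) / (hi - lo) for x in xs]

def combined_heatmap(dwi, adc):
    """Combined DWI/ADC heatmap following the recipe above."""
    dwi_h = minmax(dwi)                   # high DWI -> high attention
    adc_h = minmax([-a for a in adc])     # low ADC -> high attention (inverted)
    prod = [d * a for d, a in zip(dwi_h, adc_h)]
    return minmax(prod)                   # re-normalize to [0, 1]

h = combined_heatmap(dwi=[10, 50, 90], adc=[1200, 800, 400])
print(h)  # [0.0, 0.25, 1.0]
```

Note that `minmax([-a for a in adc])` is algebraically identical to `(max - adc) / (max - min)`.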

!!! note "Step Dependencies"

    Steps must run in the order shown above. The pipeline validates dependencies automatically — for example, `get_heatmap` requires `get_segmentation_mask` and `histogram_match` to have run first.

## Stage 1: PI-RADS Classification

Trains a 4-class PI-RADS classifier (grades 2–5, mapped to labels 0–3).

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml
```

**Training details:**

| Component | Value |
|-----------|-------|
| Loss | CrossEntropy + cosine-similarity attention loss |
| Attention loss weight | Linear warmup over 25 epochs to `lambda=2.0` |
| Optimizer | AdamW (base LR `3e-5`, transformer LR `6e-5`) |
| Scheduler | CosineAnnealingLR |
| Metric | Quadratic Weighted Kappa (QWK) |
| Early stopping | After 40 epochs without validation loss improvement |
| AMP | Disabled by default (enabled in example YAML config) |

**Attention loss**: For each batch, the model's learned attention weights are compared against heatmap-derived attention labels via cosine similarity. PI-RADS 2 samples receive uniform attention (no lesion expected). The loss is weighted by `lambda_att`, which warms up linearly over the first 25 epochs.
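The attention-loss term can be sketched as follows (an illustration under stated assumptions: `attention_loss`, its argument names, and the `1 - cosine` form are a plausible reading of the description above, not the repository's exact implementation):

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def attention_loss(att_weights, att_labels, epoch, warmup_epochs=25, lam_max=2.0):
    """1 - cosine similarity, scaled by a linearly warmed-up lambda."""
    lam = lam_max * min(epoch / warmup_epochs, 1.0)
    return lam * (1.0 - cosine(att_weights, att_labels))

# A PI-RADS 2 sample targets uniform attention over its patches.
uniform = [0.25] * 4
att = [0.25, 0.25, 0.25, 0.25]
assert attention_loss(att, uniform, epoch=25) == 0.0  # perfect agreement
```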

## Stage 2: csPCa Risk Prediction

Builds on a frozen PI-RADS backbone to predict binary csPCa risk.

```bash
python run_cspca.py --mode train --config config/config_cspca_train.yaml
```

**Training details:**

| Component | Value |
|-----------|-------|
| Loss | Binary Cross-Entropy (BCE) |
| Backbone | Frozen PI-RADS model (ResNet18 + Transformer); attention module is trainable |
| Head | SimpleNN: `512 → 256 → 128 → 1` with ReLU + Dropout(0.3) + Sigmoid |
| Optimizer | AdamW (LR `2e-4`) |
| Seeds | 20 random seeds (default) for 95% CI |
| Metrics | AUC, Sensitivity, Specificity |

The backbone's feature extractor (`net`), transformer, and `myfc` are frozen. The attention module and `SimpleNN` classification head are trained. After training across all seeds, the framework reports mean and 95% confidence intervals for AUC, sensitivity, and specificity.
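The multi-seed summary can be sketched with a normal-approximation interval (an assumption; the repository may compute the CI differently, and the per-seed AUC values below are purely illustrative):

```python
import math
import statistics

def ci95(values):
    """Mean and normal-approximation 95% CI over per-seed metric values."""
    mean = statistics.mean(values)
    half = 1.96 * statistics.stdev(values) / math.sqrt(len(values))
    return mean, (mean - half, mean + half)

aucs = [0.81, 0.84, 0.79, 0.83, 0.82]  # illustrative per-seed AUCs
mean, (lo, hi) = ci95(aucs)
print(f"AUC {mean:.3f} (95% CI {lo:.3f}-{hi:.3f})")
```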
mkdocs.yml
ADDED
@@ -0,0 +1,67 @@
site_name: WSAttention-Prostate
site_description: Weakly-supervised attention-based 3D MIL for prostate cancer risk prediction on multiparametric MRI
repo_url: https://github.com/ai-assisted-healthcare/WSAttention-Prostate
repo_name: WSAttention-Prostate

theme:
  name: material
  logo: assets/logo.svg
  favicon: assets/logo.svg
  palette:
    - scheme: default
      primary: indigo
      accent: teal
      toggle:
        icon: material/brightness-7
        name: Switch to dark mode
    - scheme: slate
      primary: indigo
      accent: teal
      toggle:
        icon: material/brightness-4
        name: Switch to light mode
  font:
    text: Inter
    code: JetBrains Mono
  features:
    - navigation.instant
    - navigation.sections
    - navigation.top
    - search.suggest
    - search.highlight
    - content.code.copy
    - content.tabs.link

plugins:
  - search
  # TODO: add mkdocstrings[python] plugin for autodoc support

markdown_extensions:
  - admonition
  - pymdownx.details
  - pymdownx.superfences:
      custom_fences:
        - name: mermaid
          class: mermaid
          format: !!python/name:pymdownx.superfences.fence_code_format
  - pymdownx.highlight:
      anchor_linenums: true
  - pymdownx.inlinehilite
  - pymdownx.tabbed:
      alternate_style: true
  - tables
  - attr_list
  - md_in_html

nav:
  - Home: index.md
  - Getting Started: getting-started.md
  - Pipeline: pipeline.md
  - Architecture: architecture.md
  - Configuration: configuration.md
  - Inference: inference.md
  - API Reference:
    - Models: api/models.md
    - Preprocessing: api/preprocessing.md
    - Data Loading: api/data.md
  - Contributing: contributing.md