kbressem commited on
Commit
cd89698
Β·
1 Parent(s): 1baebae

Add documentation site

Browse files

Add full project documentation and tooling.

.github/workflows/docs.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Deploy Documentation
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ paths:
7
+ - 'docs/**'
8
+ - 'mkdocs.yml'
9
+ workflow_dispatch:
10
+
11
+ permissions:
12
+ contents: write
13
+
14
+ jobs:
15
+ deploy:
16
+ runs-on: ubuntu-latest
17
+ steps:
18
+ - uses: actions/checkout@v4
19
+
20
+ - uses: actions/setup-python@v5
21
+ with:
22
+ python-version: '3.11'
23
+
24
+ - run: pip install mkdocs-material
25
+
26
+ - run: mkdocs gh-deploy --force
.gitignore CHANGED
@@ -6,4 +6,5 @@ temp.ipynb
6
  __pycache__/
7
  **/__pycache__/
8
  *.pyc
9
- .ruff_cache
 
 
6
  __pycache__/
7
  **/__pycache__/
8
  *.pyc
9
+ .ruff_cache
10
+ site/
README.md CHANGED
@@ -1 +1,122 @@
1
- # WSAttention-Prostate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="docs/assets/logo.svg" alt="WSAttention-Prostate Logo" width="240">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <img src="https://img.shields.io/badge/python-3.11-blue?logo=python&logoColor=white" alt="Python 3.11">
7
+ <img src="https://img.shields.io/badge/pytorch-2.5-ee4c2c?logo=pytorch&logoColor=white" alt="PyTorch 2.5">
8
+ <img src="https://img.shields.io/badge/MONAI-1.4-3ddc84" alt="MONAI 1.4">
9
+ <img src="https://img.shields.io/badge/license-Apache%202.0-green" alt="License">
10
+ <a href="https://ai-assisted-healthcare.github.io/WSAttention-Prostate/"><img src="https://img.shields.io/badge/docs-mkdocs-blue" alt="Docs"></a>
11
+ </p>
12
+
13
+ # WSAttention-Prostate
14
+
15
+ **Weakly-supervised attention-based 3D Multiple Instance Learning for prostate cancer risk prediction on multiparametric MRI.**
16
+
17
+ WSAttention-Prostate is a two-stage deep learning pipeline that predicts clinically significant prostate cancer (csPCa) risk from T2-weighted, DWI, and ADC MRI sequences. It uses 3D patch-based Multiple Instance Learning with transformer attention to first classify PI-RADS scores, then predict csPCa risk β€” all without requiring lesion-level annotations.
18
+
19
+ ## Key Features
20
+
21
+ - **Weakly-supervised attention** β€” Heatmap-guided patch sampling and cosine-similarity attention loss replace the need for voxel-level labels
22
+ - **3D Multiple Instance Learning** β€” Extracts volumetric patches from MRI scans and aggregates them via transformer + attention pooling
23
+ - **Two-stage pipeline** β€” Stage 1 trains a 4-class PI-RADS classifier; Stage 2 freezes its backbone and trains a binary csPCa head
24
+ - **Multi-seed confidence intervals** β€” Runs 20 random seeds and reports 95% CI on AUC, sensitivity, and specificity
25
+ - **End-to-end preprocessing** β€” Registration, segmentation, histogram matching, and heatmap generation in a single configurable pipeline
26
+
27
+ ## Pipeline Overview
28
+
29
+ ```mermaid
30
+ flowchart LR
31
+ A[Raw MRI\nT2 + DWI + ADC] --> B[Preprocessing]
32
+ B --> C[Stage 1:\nPI-RADS Classification]
33
+ C --> D[Stage 2:\ncsPCa Prediction]
34
+ D --> E[Risk Score\n+ Top-5 Patches]
35
+ ```
36
+
37
+ ## Quick Start
38
+
39
+ ```bash
40
+ git clone https://github.com/ai-assisted-healthcare/WSAttention-Prostate.git
41
+ cd WSAttention-Prostate
42
+ pip install -r requirements.txt
43
+ pytest tests/
44
+ ```
45
+
46
+ ## Usage
47
+
48
+ ### Preprocessing
49
+
50
+ ```bash
51
+ python preprocess_main.py --config config/config_preprocess.yaml \
52
+ --steps register_and_crop get_segmentation_mask histogram_match get_heatmap
53
+ ```
54
+
55
+ ### PI-RADS Training
56
+
57
+ ```bash
58
+ python run_pirads.py --mode train --config config/config_pirads_train.yaml
59
+ ```
60
+
61
+ ### csPCa Training
62
+
63
+ ```bash
64
+ python run_cspca.py --mode train --config config/config_cspca_train.yaml
65
+ ```
66
+
67
+ ### Inference
68
+
69
+ ```bash
70
+ python run_pirads.py --mode test --config config/config_pirads_test.yaml --checkpoint <path>
71
+ python run_cspca.py --mode test --config config/config_cspca_test.yaml --checkpoint_cspca <path>
72
+ python run_inference.py --config config/config_preprocess.yaml
73
+ ```
74
+
75
+ See the [full documentation](https://ai-assisted-healthcare.github.io/WSAttention-Prostate/) for detailed configuration options and data format requirements.
76
+
77
+ ## Project Structure
78
+
79
+ ```
80
+ WSAttention-Prostate/
81
+ β”œβ”€β”€ run_pirads.py # PI-RADS training/testing entry point
82
+ β”œβ”€β”€ run_cspca.py # csPCa training/testing entry point
83
+ β”œβ”€β”€ run_inference.py # Full inference pipeline
84
+ β”œβ”€β”€ preprocess_main.py # Preprocessing entry point
85
+ β”œβ”€β”€ config/ # YAML configuration files
86
+ β”œβ”€β”€ src/
87
+ β”‚ β”œβ”€β”€ model/
88
+ β”‚ β”‚ β”œβ”€β”€ MIL.py # MILModel_3D β€” core MIL architecture
89
+ β”‚ β”‚ └── csPCa_model.py # csPCa_Model + SimpleNN head
90
+ β”‚ β”œβ”€β”€ data/
91
+ β”‚ β”‚ β”œβ”€β”€ data_loader.py # MONAI data pipeline
92
+ β”‚ β”‚ └── custom_transforms.py
93
+ β”‚ β”œβ”€β”€ train/
94
+ β”‚ β”‚ β”œβ”€β”€ train_pirads.py # PI-RADS training loop
95
+ β”‚ β”‚ └── train_cspca.py # csPCa training loop
96
+ β”‚ β”œβ”€β”€ preprocessing/ # Registration, segmentation, heatmaps
97
+ β”‚ └── utils.py # Shared utilities and step validation
98
+ β”œβ”€β”€ tests/
99
+ β”œβ”€β”€ dataset/ # Reference images for histogram matching
100
+ └── models/ # Downloaded checkpoints (not in repo)
101
+ ```
102
+
103
+ ## Architecture
104
+
105
+ Input MRI patches are processed independently through a 3D ResNet18 backbone, then aggregated via a transformer encoder and attention pooling:
106
+
107
+ ```mermaid
108
+ flowchart TD
109
+ A["Input [B, N, C, D, H, W]"] --> B["Reshape to [B*N, C, D, H, W]"]
110
+ B --> C[ResNet18-3D Backbone]
111
+ C --> D["Reshape to [B, N, 512]"]
112
+ D --> E[Transformer Encoder\n4 layers, 8 heads]
113
+ E --> F[Attention Pooling\n512 β†’ 2048 β†’ 1]
114
+ F --> G["Weighted Sum [B, 512]"]
115
+ G --> H["FC Head [B, num_classes]"]
116
+ ```
117
+
118
+ For csPCa prediction, the backbone is frozen and a 3-layer MLP (`512 β†’ 256 β†’ 128 β†’ 1`) replaces the classification head.
119
+
120
+ ## License
121
+
122
+ Apache-2.0 β€” see [LICENSE](LICENSE).
docs/api/data.md ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Loading Reference
2
+
3
+ ## get_dataloader
4
+
5
+ ```python
6
+ def get_dataloader(args, split: Literal["train", "test"]) -> DataLoader
7
+ ```
8
+
9
+ Creates a PyTorch DataLoader with MONAI transforms and persistent caching.
10
+
11
+ **Parameters:**
12
+
13
+ | Parameter | Description |
14
+ |-----------|-------------|
15
+ | `args` | Namespace with `dataset_json`, `data_root`, `tile_size`, `tile_count`, `depth`, `use_heatmap`, `batch_size`, `workers`, `dry_run`, `logdir` |
16
+ | `split` | `"train"` or `"test"` |
17
+
18
+ **Behavior:**
19
+
20
+ - Loads data lists from a MONAI decathlon-format JSON
21
+ - In `dry_run` mode, limits to 8 samples
22
+ - Uses `PersistentDataset` with cache stored at `<logdir>/cache/<split>/`
23
+ - Training split is shuffled; test split is not
24
+ - Uses `list_data_collate` to stack patches into `[B, N, C, D, H, W]`
25
+
26
+ ## Transform Pipeline
27
+
28
+ Two variants depending on `args.use_heatmap`:
29
+
30
+ ### With Heatmaps (default)
31
+
32
+ | Step | Transform | Description |
33
+ |------|-----------|-------------|
34
+ | 1 | `LoadImaged` | Load T2, mask, DWI, ADC, heatmap (ITKReader, channel-first) |
35
+ | 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
36
+ | 3 | `ConcatItemsd` | Stack T2 + DWI + ADC β†’ 3-channel image |
37
+ | 4 | `NormalizeIntensity_customd` | Z-score normalize per channel using mask-only statistics |
38
+ | 5 | `ElementwiseProductd` | Multiply mask * heatmap β†’ `final_heatmap` |
39
+ | 6 | `RandWeightedCropd` | Extract N patches weighted by `final_heatmap` |
40
+ | 7 | `EnsureTyped` | Cast labels to float32 |
41
+ | 8 | `Transposed` | Reorder image dims for 3D convolution |
42
+ | 9 | `DeleteItemsd` | Remove intermediate keys (mask, dwi, adc, heatmap) |
43
+ | 10 | `ToTensord` | Convert to PyTorch tensors |
44
+
45
+ ### Without Heatmaps
46
+
47
+ | Step | Transform | Description |
48
+ |------|-----------|-------------|
49
+ | 1 | `LoadImaged` | Load T2, mask, DWI, ADC |
50
+ | 2 | `ClipMaskIntensityPercentilesd` | Clip T2 intensity to [0, 99.5] percentiles within mask |
51
+ | 3 | `ConcatItemsd` | Stack T2 + DWI + ADC β†’ 3-channel image |
52
+ | 4 | `NormalizeIntensityd` | Standard channel-wise normalization (MONAI built-in) |
53
+ | 5 | `RandCropByPosNegLabeld` | Extract N patches from positive (mask) regions |
54
+ | 6 | `EnsureTyped` | Cast labels to float32 |
55
+ | 7 | `Transposed` | Reorder image dims |
56
+ | 8 | `DeleteItemsd` | Remove intermediate keys |
57
+ | 9 | `ToTensord` | Convert to tensors |
58
+
59
+ ## list_data_collate
60
+
61
+ ```python
62
+ def list_data_collate(batch: Sequence) -> dict
63
+ ```
64
+
65
+ Custom collation function that stacks per-patient patch lists into batch tensors.
66
+
67
+ Each sample from the dataset is a list of N patch dictionaries. This function:
68
+
69
+ 1. Stacks `image` across patches: `[N, C, D, H, W]` per sample
70
+ 2. Stacks `final_heatmap` if present
71
+ 3. Applies PyTorch's `default_collate` to form the batch dimension
72
+
73
+ Result: `{"image": [B, N, C, D, H, W], "label": [B], ...}`
74
+
75
+ ## Custom Transforms
76
+
77
+ ### ClipMaskIntensityPercentilesd
78
+
79
+ ```python
80
+ ClipMaskIntensityPercentilesd(
81
+ keys: KeysCollection,
82
+ mask_key: str,
83
+ lower: float | None,
84
+ upper: float | None,
85
+ sharpness_factor: float | None = None,
86
+ channel_wise: bool = False,
87
+ dtype: DtypeLike = np.float32,
88
+ )
89
+ ```
90
+
91
+ Clips image intensity to percentiles computed only from the **masked region**. Supports both hard clipping (default) and soft clipping (via `sharpness_factor`).
92
+
93
+ ### NormalizeIntensity_customd
94
+
95
+ ```python
96
+ NormalizeIntensity_customd(
97
+ keys: KeysCollection,
98
+ mask_key: str,
99
+ subtrahend: NdarrayOrTensor | None = None,
100
+ divisor: NdarrayOrTensor | None = None,
101
+ nonzero: bool = False,
102
+ channel_wise: bool = False,
103
+ dtype: DtypeLike = np.float32,
104
+ )
105
+ ```
106
+
107
+ Z-score normalization where mean and standard deviation are computed only from **masked voxels**. Supports channel-wise normalization.
108
+
109
+ ### ElementwiseProductd
110
+
111
+ ```python
112
+ ElementwiseProductd(
113
+ keys: KeysCollection,
114
+ output_key: str,
115
+ )
116
+ ```
117
+
118
+ Computes the element-wise product of two arrays from the data dictionary and stores the result in `output_key`. Used to combine the prostate mask with the attention heatmap.
119
+
120
+ ## Dataset JSON Format
121
+
122
+ The pipeline expects a MONAI decathlon-format JSON file:
123
+
124
+ ```json
125
+ {
126
+ "train": [
127
+ {
128
+ "image": "relative/path/to/t2.nrrd",
129
+ "dwi": "relative/path/to/dwi.nrrd",
130
+ "adc": "relative/path/to/adc.nrrd",
131
+ "mask": "relative/path/to/mask.nrrd",
132
+ "heatmap": "relative/path/to/heatmap.nrrd",
133
+ "label": 2
134
+ }
135
+ ],
136
+ "test": [...]
137
+ }
138
+ ```
139
+
140
+ Paths are relative to `data_root`. The `heatmap` key is only required when `use_heatmap=True`.
docs/api/models.md ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Models Reference
2
+
3
+ ## MILModel_3D
4
+
5
+ ```python
6
+ class MILModel_3D(nn.Module):
7
+ def __init__(
8
+ self,
9
+ num_classes: int,
10
+ mil_mode: str = "att",
11
+ pretrained: bool = True,
12
+ backbone: str | nn.Module | None = None,
13
+ backbone_num_features: int | None = None,
14
+ trans_blocks: int = 4,
15
+ trans_dropout: float = 0.0,
16
+ )
17
+ ```
18
+
19
+ **Constructor arguments:**
20
+
21
+ | Argument | Type | Default | Description |
22
+ |----------|------|---------|-------------|
23
+ | `num_classes` | `int` | β€” | Number of output classes |
24
+ | `mil_mode` | `str` | `"att"` | MIL aggregation mode |
25
+ | `pretrained` | `bool` | `True` | Use pretrained backbone weights |
26
+ | `backbone` | `str \| nn.Module \| None` | `None` | Backbone CNN (None = ResNet18-3D) |
27
+ | `backbone_num_features` | `int \| None` | `None` | Output features of custom backbone |
28
+ | `trans_blocks` | `int` | `4` | Number of transformer encoder layers |
29
+ | `trans_dropout` | `float` | `0.0` | Transformer dropout rate |
30
+
31
+ **MIL modes:**
32
+
33
+ | Mode | Description |
34
+ |------|-------------|
35
+ | `mean` | Average logits across all patches β€” equivalent to pure CNN |
36
+ | `max` | Keep only the max-probability instance for loss |
37
+ | `att` | Attention-based MIL ([Ilse et al., 2018](https://arxiv.org/abs/1802.04712)) |
38
+ | `att_trans` | Transformer + attention MIL ([Shao et al., 2021](https://arxiv.org/abs/2111.01556)) |
39
+ | `att_trans_pyramid` | Pyramid transformer using intermediate ResNet layers |
40
+
41
+ **Key methods:**
42
+
43
+ - `forward(x, no_head=False)` β€” Full forward pass. If `no_head=True`, returns patch-level features `[B, N, 512]` before transformer and attention pooling (used during attention loss computation).
44
+ - `calc_head(x)` β€” Applies the MIL aggregation and classification head to patch features.
45
+
46
+ **Example:**
47
+
48
+ ```python
49
+ import torch
50
+ from src.model.MIL import MILModel_3D
51
+
52
+ model = MILModel_3D(num_classes=4, mil_mode="att_trans")
53
+ # Input: [batch, patches, channels, depth, height, width]
54
+ x = torch.randn(2, 24, 3, 3, 64, 64)
55
+ logits = model(x) # [2, 4]
56
+ ```
57
+
58
+ ## csPCa_Model
59
+
60
+ ```python
61
+ class csPCa_Model(nn.Module):
62
+ def __init__(self, backbone: nn.Module)
63
+ ```
64
+
65
+ Wraps a pre-trained `MILModel_3D` backbone for binary csPCa prediction. The backbone's feature extractor, transformer, and attention mechanism are reused. The original classification head (`myfc`) is replaced by a `SimpleNN`.
66
+
67
+ **Attributes:**
68
+
69
+ | Attribute | Type | Description |
70
+ |-----------|------|-------------|
71
+ | `backbone` | `MILModel_3D` | Frozen PI-RADS backbone |
72
+ | `fc_cspca` | `SimpleNN` | Binary classification head |
73
+ | `fc_dim` | `int` | Feature dimension (512 for ResNet18) |
74
+
75
+ **Example:**
76
+
77
+ ```python
78
+ import torch
79
+ from src.model.MIL import MILModel_3D
80
+ from src.model.csPCa_model import csPCa_Model
81
+
82
+ backbone = MILModel_3D(num_classes=4, mil_mode="att_trans")
83
+ model = csPCa_Model(backbone=backbone)
84
+
85
+ x = torch.randn(2, 24, 3, 3, 64, 64)
86
+ prob = model(x) # [2, 1] β€” sigmoid probabilities
87
+ ```
88
+
89
+ ## SimpleNN
90
+
91
+ ```python
92
+ class SimpleNN(nn.Module):
93
+ def __init__(self, input_dim: int)
94
+ ```
95
+
96
+ A lightweight MLP for binary classification:
97
+
98
+ ```
99
+ Linear(input_dim, 256) β†’ ReLU
100
+ Linear(256, 128) β†’ ReLU β†’ Dropout(0.3)
101
+ Linear(128, 1) β†’ Sigmoid
102
+ ```
103
+
104
+ Input: `[B, input_dim]` β€” Output: `[B, 1]` (probability).
docs/api/preprocessing.md ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Preprocessing Reference
2
+
3
+ ## Overview
4
+
5
+ Preprocessing is orchestrated by `preprocess_main.py`, which runs steps in sequence. Each step receives and returns the `args` namespace, updating directory paths as it goes.
6
+
7
+ ## Step Dependencies
8
+
9
+ | Step | Requires |
10
+ |------|----------|
11
+ | `register_and_crop` | β€” |
12
+ | `get_segmentation_mask` | `register_and_crop` |
13
+ | `histogram_match` | `register_and_crop`, `get_segmentation_mask` |
14
+ | `get_heatmap` | `register_and_crop`, `get_segmentation_mask`, `histogram_match` |
15
+
16
+ Dependencies are validated at runtime β€” the pipeline will exit with an error if steps are out of order.
17
+
18
+ ## register_files
19
+
20
+ ```python
21
+ def register_files(args) -> args
22
+ ```
23
+
24
+ Registers and crops T2, DWI, and ADC images to a standardized spacing and size.
25
+
26
+ **Process:**
27
+
28
+ 1. Reads images from `args.t2_dir`, `args.dwi_dir`, `args.adc_dir`
29
+ 2. Resamples to spacing `(0.4, 0.4, 3.0)` mm using `picai_prep.Sample`
30
+ 3. Center-crops with `args.margin` (default 0.2) in x/y dimensions
31
+ 4. Saves to `<output_dir>/t2_registered/`, `DWI_registered/`, `ADC_registered/`
32
+
33
+ **Updates `args`:** `t2_dir`, `dwi_dir`, `adc_dir` β†’ registered directories.
34
+
35
+ ## get_segmask
36
+
37
+ ```python
38
+ def get_segmask(args) -> args
39
+ ```
40
+
41
+ Generates prostate segmentation masks from T2W images using a pre-trained model.
42
+
43
+ **Process:**
44
+
45
+ 1. Loads model config from `<project_dir>/config/inference.json`
46
+ 2. Loads checkpoint from `<project_dir>/models/prostate_segmentation_model.pt`
47
+ 3. Applies MONAI transforms: orientation (RAS), spacing (0.5 mm isotropic), intensity normalization
48
+ 4. Runs inference and inverts transforms to original space
49
+ 5. Post-processes: retains only top 10 slices by non-zero voxel count
50
+ 6. Saves NRRD masks to `<output_dir>/prostate_mask/`
51
+
52
+ **Updates `args`:** adds `seg_dir`.
53
+
54
+ ## histmatch
55
+
56
+ ```python
57
+ def histmatch(args) -> args
58
+ ```
59
+
60
+ Matches the intensity histogram of each modality to a reference image.
61
+
62
+ **Process:**
63
+
64
+ 1. Reads reference images from `<project_dir>/dataset/` (`t2_reference.nrrd`, `dwi_reference.nrrd`, `adc_reference.nrrd`, `prostate_segmentation_reference.nrrd`)
65
+ 2. For each patient, matches histograms within the prostate mask using `skimage.exposure.match_histograms`
66
+ 3. Saves to `<output_dir>/t2_histmatched/`, `DWI_histmatched/`, `ADC_histmatched/`
67
+
68
+ **Updates `args`:** `t2_dir`, `dwi_dir`, `adc_dir` β†’ histogram-matched directories.
69
+
70
+ ### get_histmatched
71
+
72
+ ```python
73
+ def get_histmatched(
74
+ data: np.ndarray,
75
+ ref_data: np.ndarray,
76
+ mask: np.ndarray,
77
+ ref_mask: np.ndarray,
78
+ ) -> np.ndarray
79
+ ```
80
+
81
+ Low-level function that performs histogram matching on masked regions only. Unmasked pixels remain unchanged.
82
+
83
+ ## get_heatmap
84
+
85
+ ```python
86
+ def get_heatmap(args) -> args
87
+ ```
88
+
89
+ Generates combined DWI/ADC attention heatmaps.
90
+
91
+ **Process:**
92
+
93
+ 1. For each file, reads DWI, ADC, and prostate mask
94
+ 2. Computes DWI heatmap: `(dwi - min) / (max - min)` within mask
95
+ 3. Computes ADC heatmap: `(max - adc) / (max - min)` within mask (inverted β€” low ADC = high attention)
96
+ 4. Combines via element-wise multiplication
97
+ 5. Re-normalizes to [0, 1]
98
+ 6. Saves to `<output_dir>/heatmaps/`
99
+
100
+ **Updates `args`:** adds `heatmapdir`.
101
+
102
+ !!! info "Edge cases"
103
+ If all values within the mask are identical for a modality (DWI or ADC), that modality's heatmap is skipped. If both are constant, the heatmap defaults to all ones.
docs/architecture.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Architecture
2
+
3
+ ## Tensor Shape Convention
4
+
5
+ Throughout the pipeline, tensors follow the shape `[B, N, C, D, H, W]`:
6
+
7
+ | Dim | Meaning | Typical Value |
8
+ |-----|---------|---------------|
9
+ | B | Batch size | 4–8 |
10
+ | N | Number of patches (instances) | 24 |
11
+ | C | Channels (T2 + DWI + ADC) | 3 |
12
+ | D | Depth (slices per patch) | 3 |
13
+ | H | Patch height | 64 |
14
+ | W | Patch width | 64 |
15
+
16
+ ## MILModel_3D
17
+
18
+ The core model processes each patch independently through a CNN backbone, then aggregates patch-level features via a transformer encoder and attention pooling.
19
+
20
+ ```mermaid
21
+ flowchart TD
22
+ A["Input [B, N, C, D, H, W]"] --> B["Reshape to [B*N, C, D, H, W]"]
23
+ B --> C[ResNet18-3D Backbone]
24
+ C --> D["Reshape to [B, N, 512]"]
25
+ D --> E[Transformer Encoder\n4 layers, 8 heads]
26
+ E --> F[Attention Pooling\n512 β†’ 2048 β†’ 1]
27
+ F --> G["Weighted Sum [B, 512]"]
28
+ G --> H["FC Head [B, num_classes]"]
29
+ ```
30
+
31
+ ### Forward Pass
32
+
33
+ 1. **Backbone**: Input is reshaped from `[B, N, C, D, H, W]` to `[B*N, C, D, H, W]` and passed through a 3D ResNet18 (with 3 input channels). The final FC layer is removed, yielding 512-dimensional features per patch.
34
+
35
+ 2. **Transformer**: Features are reshaped to `[B, N, 512]`, permuted to `[N, B, 512]` for the transformer encoder (4 layers, 8 attention heads), then permuted back.
36
+
37
+ 3. **Attention**: A two-layer attention network (`512 β†’ 2048 β†’ 1` with Tanh) computes a scalar weight per patch, normalized via softmax.
38
+
39
+ 4. **Classification**: The attention-weighted sum of patch features produces a single `[B, 512]` vector per scan, which is projected to class logits by a linear layer.
40
+
41
+ ### MIL Modes
42
+
43
+ | Mode | Aggregation Strategy |
44
+ |------|---------------------|
45
+ | `mean` | Average logits across patches |
46
+ | `max` | Max logits across patches |
47
+ | `att` | Attention-weighted feature pooling |
48
+ | `att_trans` | Transformer encoder + attention pooling (primary mode) |
49
+ | `att_trans_pyramid` | Pyramid transformer on intermediate ResNet layers + attention |
50
+
51
+ The default and primary mode is `att_trans`.
52
+
53
+ ## csPCa_Model
54
+
55
+ Wraps a frozen `MILModel_3D` backbone and replaces the classification head:
56
+
57
+ ```mermaid
58
+ flowchart TD
59
+ A["Input [B, N, C, D, H, W]"] --> B["Frozen Backbone\n(ResNet18 + Transformer)"]
60
+ B --> C["Pooled Features [B, 512]"]
61
+ C --> D["SimpleNN Head\n512 β†’ 256 β†’ 128 β†’ 1"]
62
+ D --> E["Sigmoid β†’ csPCa Probability"]
63
+ ```
64
+
65
+ ### SimpleNN
66
+
67
+ ```
68
+ Linear(512, 256) β†’ ReLU
69
+ Linear(256, 128) β†’ ReLU β†’ Dropout(0.3)
70
+ Linear(128, 1) β†’ Sigmoid
71
+ ```
72
+
73
+ During csPCa training, the backbone's `net` (ResNet18), `transformer`, and `myfc` parameters are frozen. The `attention` module and `SimpleNN` head remain trainable.
74
+
75
+ ## Attention Loss
76
+
77
+ During PI-RADS training with heatmaps enabled, the model uses a dual-loss objective:
78
+
79
+ ```
80
+ total_loss = class_loss + lambda_att * attention_loss
81
+ ```
82
+
83
+ - **Classification loss**: Standard CrossEntropy on PI-RADS labels
84
+ - **Attention loss**: `1 - cosine_similarity(predicted_attention, heatmap_attention)`
85
+ - Heatmap-derived attention labels are computed by summing spatial heatmap values per patch, squaring for sharpness, and normalizing
86
+ - PI-RADS 2 samples get uniform attention (no expected lesion)
87
+ - `lambda_att` warms up linearly from 0 to 2.0 over the first 25 epochs
88
+ - The attention predictions are computed with detached transformer outputs to avoid gradient interference with classification
89
+
90
+ ## Patch Extraction
91
+
92
+ Patches are extracted using MONAI's `RandWeightedCropd` (when heatmaps are available) or `RandCropByPosNegLabeld` (without heatmaps):
93
+
94
+ - **With heatmaps**: The combined DWI/ADC heatmap multiplied by the prostate mask serves as the sampling weight map β€” regions with high DWI and low ADC are sampled more frequently
95
+ - **Without heatmaps**: Crops are sampled from positive (prostate) regions based on the binary mask
96
+
97
+ Each scan yields `N` patches (default 24) of size `tile_size x tile_size x depth` (default 64x64x3).
docs/assets/logo.svg ADDED
docs/configuration.md ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration
2
+
3
+ ## Config System
4
+
5
+ Configuration follows a three-level hierarchy:
6
+
7
+ 1. **CLI defaults** β€” Argparse defaults in `run_pirads.py`, `run_cspca.py`, etc.
8
+ 2. **YAML overrides** β€” Values from `--config <file>.yaml` override CLI defaults
9
+ 3. **SLURM job name** β€” If `SLURM_JOB_NAME` is set, it overrides `run_name`
10
+
11
+ ```bash
12
+ # CLI defaults are overridden by YAML config
13
+ python run_pirads.py --mode train --config config/config_pirads_train.yaml
14
+ ```
15
+
16
+ !!! note
17
+ YAML values **always override** CLI defaults for any key present in the YAML file (`args.__dict__.update(config)`). To override a YAML value, edit the YAML file or omit the key from YAML so the CLI default is used.
18
+
19
+ ## PI-RADS Training Parameters
20
+
21
+ | Parameter | Default | Description |
22
+ |-----------|---------|-------------|
23
+ | `mode` | β€” | `train` or `test` (required) |
24
+ | `config` | β€” | Path to YAML config file |
25
+ | `data_root` | β€” | Root folder of images |
26
+ | `dataset_json` | β€” | Path to dataset JSON file |
27
+ | `num_classes` | `4` | Number of output classes (PI-RADS 2–5) |
28
+ | `mil_mode` | `att_trans` | MIL algorithm (`mean`, `max`, `att`, `att_trans`, `att_trans_pyramid`) |
29
+ | `tile_count` | `24` | Number of patches per scan |
30
+ | `tile_size` | `64` | Patch spatial size in pixels |
31
+ | `depth` | `3` | Number of slices per patch |
32
+ | `use_heatmap` | `True` | Enable heatmap-guided patch sampling |
33
+ | `workers` | `2` | DataLoader workers |
34
+ | `checkpoint` | `None` | Path to resume from checkpoint |
35
+ | `epochs` | `50` | Max training epochs |
36
+ | `early_stop` | `40` | Epochs without improvement before stopping |
37
+ | `batch_size` | `4` | Scans per batch |
38
+ | `optim_lr` | `3e-5` | Base learning rate |
39
+ | `weight_decay` | `0` | Optimizer weight decay |
40
+ | `amp` | `False` | Enable automatic mixed precision |
41
+ | `val_every` | `1` | Validation frequency (epochs) |
42
+ | `wandb` | `False` | Enable Weights & Biases logging |
43
+ | `project_name` | `Classification_prostate` | W&B project name |
44
+ | `run_name` | `train_pirads` | Run name for logging |
45
+ | `dry_run` | `False` | Quick test mode |
46
+
47
+ ## csPCa Training Parameters
48
+
49
+ | Parameter | Default | Description |
50
+ |-----------|---------|-------------|
51
+ | `mode` | β€” | `train` or `test` (required) |
52
+ | `config` | β€” | Path to YAML config file |
53
+ | `data_root` | β€” | Root folder of images |
54
+ | `dataset_json` | β€” | Path to dataset JSON file |
55
+ | `num_classes` | `4` | PI-RADS classes (for backbone initialization) |
56
+ | `mil_mode` | `att_trans` | MIL algorithm for backbone |
57
+ | `tile_count` | `24` | Number of patches per scan |
58
+ | `tile_size` | `64` | Patch spatial size |
59
+ | `depth` | `3` | Slices per patch |
60
+ | `use_heatmap` | `True` | Enable heatmap-guided patch sampling |
61
+ | `workers` | `2` | DataLoader workers |
62
+ | `checkpoint_pirads` | β€” | Path to pre-trained PI-RADS model (required for train) |
63
+ | `checkpoint_cspca` | β€” | Path to csPCa checkpoint (required for test) |
64
+ | `epochs` | `30` | Max training epochs |
65
+ | `batch_size` | `32` | Scans per batch |
66
+ | `optim_lr` | `2e-4` | Learning rate |
67
+ | `num_seeds` | `20` | Number of random seeds for CI |
68
+ | `val_every` | `1` | Validation frequency |
69
+ | `dry_run` | `False` | Quick test mode |
70
+
71
+ ## Preprocessing Parameters
72
+
73
+ | Parameter | Default | Description |
74
+ |-----------|---------|-------------|
75
+ | `config` | β€” | Path to YAML config file |
76
+ | `steps` | β€” | Steps to execute (required, one or more) |
77
+ | `t2_dir` | β€” | Directory of T2W images |
78
+ | `dwi_dir` | β€” | Directory of DWI images |
79
+ | `adc_dir` | β€” | Directory of ADC images |
80
+ | `seg_dir` | β€” | Directory of segmentation masks |
81
+ | `output_dir` | β€” | Output directory |
82
+ | `margin` | `0.2` | Center-crop margin fraction |
83
+ | `project_dir` | β€” | Project root (for reference images and models) |
84
+
85
+ ## Example YAML
86
+
87
+ === "PI-RADS Training"
88
+
89
+ ```yaml
90
+ data_root: /path/to/registered/t2_hist_matched
91
+ dataset_json: /path/to/PI-RADS_data.json
92
+ num_classes: 4
93
+ mil_mode: att_trans
94
+ tile_count: 24
95
+ tile_size: 64
96
+ depth: 3
97
+ use_heatmap: true
98
+ workers: 4
99
+ epochs: 100
100
+ batch_size: 8
101
+ optim_lr: 2e-4
102
+ weight_decay: 1e-5
103
+ amp: true
104
+ wandb: true
105
+ ```
106
+
107
+ === "csPCa Training"
108
+
109
+ ```yaml
110
+ data_root: /path/to/registered/t2_hist_matched
111
+ dataset_json: /path/to/csPCa_data.json
112
+ num_classes: 4
113
+ mil_mode: att_trans
114
+ tile_count: 24
115
+ tile_size: 64
116
+ depth: 3
117
+ use_heatmap: true
118
+ workers: 6
119
+ checkpoint_pirads: /path/to/models/pirads.pt
120
+ epochs: 80
121
+ batch_size: 8
122
+ optim_lr: 2e-4
123
+ ```
124
+
125
+ === "Preprocessing"
126
+
127
+ ```yaml
128
+ t2_dir: /path/to/raw/t2
129
+ dwi_dir: /path/to/raw/dwi
130
+ adc_dir: /path/to/raw/adc
131
+ output_dir: /path/to/processed
132
+ project_dir: /path/to/WSAttention-Prostate
133
+ ```
134
+
135
+ ## Dry-Run Mode
136
+
137
+ The `--dry_run` flag configures a minimal run for quick testing:
138
+
139
+ - Epochs: 2
140
+ - Batch size: 2
141
+ - Workers: 0
142
+ - Seeds: 2
143
+ - W&B: disabled
144
+
145
+ ```bash
146
+ python run_pirads.py --mode train --config config/config_pirads_train.yaml --dry_run
147
+ ```
docs/contributing.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # Contributing
+
+ ## Running Tests
+
+ ```bash
+ # Full test suite
+ pytest tests/
+
+ # Single test
+ pytest tests/test_run.py::test_run_pirads_training
+ ```
+
+ Tests run in `--dry_run` mode (2 epochs, batch_size=2, no W&B logging).
+
+ ## Linting
+
+ This project uses [Ruff](https://docs.astral.sh/ruff/) for linting and formatting:
+
+ ```bash
+ # Check for lint errors
+ ruff check .
+
+ # Auto-format code
+ ruff format .
+ ```
+
+ **Ruff configuration** (from `pyproject.toml`):
+
+ | Setting | Value |
+ |---------|-------|
+ | Line length | 100 |
+ | Quote style | Double quotes |
+ | Rules | E (errors), W (warnings) |
+ | Ignored | E501 (line too long) |
+
+ ## SLURM Job Scripts
+
+ Job scripts are in `job_scripts/` and are configured for GPU partitions:
+
+ ```bash
+ sbatch job_scripts/train_pirads.sh
+ sbatch job_scripts/train_cspca.sh
+ ```
+
+ Key SLURM settings used:
+
+ | Setting | Value |
+ |---------|-------|
+ | Partition | `gpu` |
+ | Memory | 128 GB |
+ | GPUs | 1 |
+ | Time limit | 48 hours |
+
+ !!! tip
+     The SLURM job name (`--job-name`) automatically becomes the `run_name`, which determines the log directory at `logs/<run_name>/`.
+
+ ## Project Conventions
+
+ - **Configs** are stored in `config/` as YAML files
+ - **Logs** are written to `logs/<run_name>/` including TensorBoard events and training logs
+ - **Models** are saved to `logs/<run_name>/` during training; best models are saved to `models/` for deployment
+ - **Cache** is stored at `logs/<run_name>/cache/` and cleaned up automatically after training
docs/getting-started.md ADDED
@@ -0,0 +1,104 @@
+ # Getting Started
+
+ ## Prerequisites
+
+ - Python 3.11+
+ - NVIDIA GPU recommended (CUDA-compatible)
+ - ~128 GB RAM for training (configurable via batch size)
+
+ ## Installation
+
+ ```bash
+ git clone https://github.com/ai-assisted-healthcare/WSAttention-Prostate.git
+ cd WSAttention-Prostate
+ pip install -r requirements.txt
+ ```
+
+ ### External Git Dependencies
+
+ Two packages are installed directly from GitHub repositories:
+
+ | Package | Source | Purpose |
+ |---------|--------|---------|
+ | `AIAH_utility` | `ai-assisted-healthcare/AIAH_utility` | Healthcare imaging utilities |
+ | `grad-cam` | `jacobgil/pytorch-grad-cam` | Gradient-weighted class activation maps |
+
+ These are included in `requirements.txt` and install automatically.
+
+ ## Verify Installation
+
+ Run the test suite in dry-run mode:
+
+ ```bash
+ pytest tests/
+ ```
+
+ Tests use `--dry_run` mode internally (2 epochs, batch_size=2, no W&B).
+
+ ## Data Format
+
+ Input MRI scans should be in **NRRD** or **NIfTI** format with three modalities per patient:
+
+ - T2-weighted (T2W)
+ - Diffusion-weighted imaging (DWI)
+ - Apparent diffusion coefficient (ADC)
+
+ ### Dataset JSON Structure
+
+ The data pipeline uses MONAI's decathlon-format JSON:
+
+ ```json
+ {
+   "train": [
+     {
+       "image": "path/to/t2.nrrd",
+       "dwi": "path/to/dwi.nrrd",
+       "adc": "path/to/adc.nrrd",
+       "mask": "path/to/prostate_mask.nrrd",
+       "heatmap": "path/to/heatmap.nrrd",
+       "label": 0
+     }
+   ],
+   "test": [
+     ...
+   ]
+ }
+ ```
+
+ The `image` key points to the T2W image, which serves as the reference modality. Labels for PI-RADS are 0-indexed: label `0` = PI-RADS 2, label `3` = PI-RADS 5. For csPCa, labels are binary (0 or 1).
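A quick sanity check on dataset files can save a failed training run. This is an illustrative sketch that validates the structure shown above; the `validate_dataset` helper is ours, not part of the project or MONAI:

```python
import json

# Required per-entry keys, matching the dataset JSON structure above.
REQUIRED_KEYS = {"image", "dwi", "adc", "mask", "heatmap", "label"}

def validate_dataset(data: dict) -> None:
    """Raise ValueError if any entry is missing a required key (hypothetical helper)."""
    for split in ("train", "test"):
        for i, entry in enumerate(data.get(split, [])):
            missing = REQUIRED_KEYS - entry.keys()
            if missing:
                raise ValueError(f"{split}[{i}] missing keys: {sorted(missing)}")

example = json.loads("""
{"train": [{"image": "t2.nrrd", "dwi": "dwi.nrrd", "adc": "adc.nrrd",
            "mask": "mask.nrrd", "heatmap": "heatmap.nrrd", "label": 0}],
 "test": []}
""")
validate_dataset(example)  # passes silently when all keys are present
```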
+
+ ## Project Structure
+
+ ```
+ WSAttention-Prostate/
+ β”œβ”€β”€ run_pirads.py            # PI-RADS training/testing entry point
+ β”œβ”€β”€ run_cspca.py             # csPCa training/testing entry point
+ β”œβ”€β”€ run_inference.py         # Full inference pipeline
+ β”œβ”€β”€ preprocess_main.py       # Preprocessing entry point
+ β”œβ”€β”€ config/                  # YAML configuration files
+ β”‚   β”œβ”€β”€ config_pirads_train.yaml
+ β”‚   β”œβ”€β”€ config_pirads_test.yaml
+ β”‚   β”œβ”€β”€ config_cspca_train.yaml
+ β”‚   β”œβ”€β”€ config_cspca_test.yaml
+ β”‚   └── config_preprocess.yaml
+ β”œβ”€β”€ src/
+ β”‚   β”œβ”€β”€ model/
+ β”‚   β”‚   β”œβ”€β”€ MIL.py               # MILModel_3D β€” core MIL architecture
+ β”‚   β”‚   └── csPCa_model.py       # csPCa_Model + SimpleNN head
+ β”‚   β”œβ”€β”€ data/
+ β”‚   β”‚   β”œβ”€β”€ data_loader.py       # MONAI data pipeline
+ β”‚   β”‚   └── custom_transforms.py
+ β”‚   β”œβ”€β”€ train/
+ β”‚   β”‚   β”œβ”€β”€ train_pirads.py      # PI-RADS training loop
+ β”‚   β”‚   └── train_cspca.py       # csPCa training loop
+ β”‚   β”œβ”€β”€ preprocessing/
+ β”‚   β”‚   β”œβ”€β”€ register_and_crop.py
+ β”‚   β”‚   β”œβ”€β”€ prostate_mask.py
+ β”‚   β”‚   β”œβ”€β”€ histogram_match.py
+ β”‚   β”‚   └── generate_heatmap.py
+ β”‚   └── utils.py
+ β”œβ”€β”€ job_scripts/             # SLURM job templates
+ β”œβ”€β”€ tests/
+ β”œβ”€β”€ dataset/                 # Reference images for histogram matching
+ └── models/                  # Pre-trained model checkpoints
+ ```
docs/index.md ADDED
@@ -0,0 +1,35 @@
+ <div style="text-align: center; margin-bottom: 2em;">
+   <img src="assets/logo.svg" alt="WSAttention-Prostate Logo" width="240">
+ </div>
+
+ # WSAttention-Prostate
+
+ **Weakly-supervised attention-based 3D Multiple Instance Learning for prostate cancer risk prediction on multiparametric MRI.**
+
+ WSAttention-Prostate is a two-stage deep learning pipeline that predicts clinically significant prostate cancer (csPCa) risk from T2-weighted, DWI, and ADC MRI sequences. It uses 3D patch-based Multiple Instance Learning with transformer attention to first classify PI-RADS scores, then predict csPCa risk β€” all without requiring lesion-level annotations.
+
+ ## Key Features
+
+ - **Weakly-supervised attention** β€” Heatmap-guided patch sampling and cosine-similarity attention loss replace the need for voxel-level labels
+ - **3D Multiple Instance Learning** β€” Extracts volumetric patches from MRI scans and aggregates them via transformer + attention pooling
+ - **Two-stage pipeline** β€” Stage 1 trains a 4-class PI-RADS classifier; Stage 2 freezes its backbone and trains a binary csPCa head
+ - **Multi-seed confidence intervals** β€” Runs 20 random seeds and reports 95% CI on AUC, sensitivity, and specificity
+ - **End-to-end preprocessing** β€” Registration, segmentation, histogram matching, and heatmap generation in a single configurable pipeline
+
+ ## Pipeline Overview
+
+ ```mermaid
+ flowchart LR
+     A["Raw MRI<br>T2 + DWI + ADC"] --> B[Preprocessing]
+     B --> C["Stage 1:<br>PI-RADS Classification"]
+     C --> D["Stage 2:<br>csPCa Prediction"]
+     D --> E["Risk Score<br>+ Top-5 Patches"]
+ ```
+
+ ## Quick Links
+
+ - [Getting Started](getting-started.md) β€” Installation and first run
+ - [Pipeline](pipeline.md) β€” Full walkthrough of preprocessing, training, and evaluation
+ - [Architecture](architecture.md) β€” Model design and tensor shapes
+ - [Configuration](configuration.md) β€” YAML config reference
+ - [Inference](inference.md) β€” Running predictions on new data
docs/inference.md ADDED
@@ -0,0 +1,81 @@
+ # Inference
+
+ ## Full Pipeline
+
+ `run_inference.py` runs the complete pipeline: preprocessing followed by PI-RADS classification and csPCa risk prediction.
+
+ ```bash
+ python run_inference.py --config config/config_preprocess.yaml
+ ```
+
+ This script:
+
+ 1. Runs all four preprocessing steps (register, segment, histogram match, heatmap)
+ 2. Loads the PI-RADS model from `models/pirads.pt`
+ 3. Loads the csPCa model from `models/cspca_model.pth`
+ 4. For each scan: predicts PI-RADS score, csPCa risk probability, and identifies the top-5 most-attended patches
+
+ ### Required Model Files
+
+ Place these in the `models/` directory:
+
+ | File | Description |
+ |------|-------------|
+ | `pirads.pt` | Trained PI-RADS MIL model checkpoint |
+ | `cspca_model.pth` | Trained csPCa model checkpoint |
+ | `prostate_segmentation_model.pt` | Pre-trained prostate segmentation model |
+
+ ### Output Format
+
+ Results are saved to `<output_dir>/results.json`:
+
+ ```json
+ {
+   "patient_001.nrrd": {
+     "Predicted PIRAD Score": 4.0,
+     "csPCa risk": 0.8234,
+     "Top left coordinate of top 5 patches(x,y,z)": [
+       [32, 45, 7],
+       [28, 50, 7],
+       [35, 42, 8],
+       [30, 48, 6],
+       [33, 44, 8]
+     ]
+   }
+ }
+ ```
+
+ ### Label Mapping
+
+ PI-RADS predictions are 0-indexed internally and shifted by +2 for display:
+
+ | Internal Label | PI-RADS Score |
+ |---------------|---------------|
+ | 0 | PI-RADS 2 |
+ | 1 | PI-RADS 3 |
+ | 2 | PI-RADS 4 |
+ | 3 | PI-RADS 5 |
+
+ csPCa risk is a continuous probability in [0, 1].
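The +2 shift and the results format can be exercised in a few lines. The `pirads_from_internal` helper is illustrative; the JSON keys match the example output above:

```python
import json

def pirads_from_internal(label: int) -> int:
    """Map a 0-indexed internal label to the clinical PI-RADS score (+2 shift)."""
    return int(label) + 2

print([pirads_from_internal(lbl) for lbl in range(4)])  # [2, 3, 4, 5]

# Reading a (truncated) results.json as produced by the pipeline:
results = json.loads(
    '{"patient_001.nrrd": {"Predicted PIRAD Score": 4.0, "csPCa risk": 0.8234}}'
)
for scan, pred in results.items():
    print(f"{scan}: PI-RADS {pred['Predicted PIRAD Score']:.0f}, "
          f"csPCa risk {pred['csPCa risk']:.2f}")
```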
+
+ ## Testing Individual Models
+
+ ### PI-RADS Testing
+
+ ```bash
+ python run_pirads.py --mode test \
+     --config config/config_pirads_test.yaml \
+     --checkpoint models/pirads.pt
+ ```
+
+ Reports Quadratic Weighted Kappa (QWK) across multiple seeds.
+
+ ### csPCa Testing
+
+ ```bash
+ python run_cspca.py --mode test \
+     --config config/config_cspca_test.yaml \
+     --checkpoint_cspca models/cspca_model.pth
+ ```
+
+ Reports AUC, sensitivity, and specificity with 95% confidence intervals across 20 seeds (default).
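A minimal sketch of computing a multi-seed 95% confidence interval, assuming a normal approximation (mean Β± 1.96 Β· standard error); the project may compute its intervals differently:

```python
import statistics

def ci95(values):
    """Mean and 95% CI under a normal approximation (assumption, not the project's exact method)."""
    mean = statistics.mean(values)
    se = statistics.stdev(values) / len(values) ** 0.5
    return mean, mean - 1.96 * se, mean + 1.96 * se

aucs = [0.81, 0.84, 0.79, 0.83, 0.82, 0.80, 0.85, 0.78, 0.82, 0.81]  # one AUC per seed
mean, lo, hi = ci95(aucs)
print(f"AUC {mean:.3f} (95% CI {lo:.3f}-{hi:.3f})")
```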
docs/pipeline.md ADDED
@@ -0,0 +1,99 @@
+ # Pipeline
+
+ The full pipeline has three phases: preprocessing, PI-RADS training (Stage 1), and csPCa training (Stage 2).
+
+ ```mermaid
+ flowchart TD
+     subgraph Preprocessing
+         R[register_and_crop] --> S[get_segmentation_mask]
+         S --> H[histogram_match]
+         H --> G[get_heatmap]
+     end
+
+     subgraph Stage1["Stage 1"]
+         P["PI-RADS Training<br>CrossEntropy + Attention Loss"]
+     end
+
+     subgraph Stage2["Stage 2"]
+         C["csPCa Training<br>Frozen Backbone + BCE Loss"]
+     end
+
+     G --> P
+     P -->|frozen backbone| C
+ ```
+
+ ## Preprocessing
+
+ Run all four steps in sequence:
+
+ ```bash
+ python preprocess_main.py \
+     --config config/config_preprocess.yaml \
+     --steps register_and_crop get_segmentation_mask histogram_match get_heatmap
+ ```
+
+ ### Step 1: Register and Crop
+
+ Resamples T2, DWI, and ADC to a common spacing of `(0.4, 0.4, 3.0)` mm using `picai_prep`, then center-crops with a configurable margin (default 20%).
+
+ ### Step 2: Prostate Segmentation
+
+ Runs a pre-trained segmentation model on T2W images to generate binary prostate masks. Post-processing retains only the top 10 slices by non-zero voxel count.
+
+ ### Step 3: Histogram Matching
+
+ Matches the intensity histogram of each modality to a reference image within masked (prostate) regions using `skimage.exposure.match_histograms`.
+
+ ### Step 4: Heatmap Generation
+
+ Creates attention heatmaps from DWI and ADC:
+
+ - **DWI heatmap**: `(dwi - min) / (max - min)` β€” higher DWI signal = higher attention
+ - **ADC heatmap**: `(max - adc) / (max - min)` β€” lower ADC = higher attention (inverted)
+ - **Combined**: element-wise product, re-normalized to [0, 1]
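The heatmap arithmetic above can be sketched on a toy array; the real code operates on full 3D volumes within the prostate mask, and `minmax` is our illustrative helper:

```python
import numpy as np

def minmax(x):
    """Rescale an array to [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())

dwi = np.array([[100.0, 400.0], [250.0, 700.0]])
adc = np.array([[1500.0, 600.0], [1100.0, 400.0]])

dwi_heat = minmax(dwi)                 # high DWI signal -> high attention
adc_heat = minmax(adc.max() - adc)     # low ADC -> high attention (inverted)
heatmap = minmax(dwi_heat * adc_heat)  # combined, re-normalized to [0, 1]

print(heatmap.min(), heatmap.max())  # 0.0 1.0
```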
+
+ !!! note "Step Dependencies"
+     Steps must run in the order shown above. The pipeline validates dependencies automatically β€” for example, `get_heatmap` requires `get_segmentation_mask` and `histogram_match` to have run first.
+
+ ## Stage 1: PI-RADS Classification
+
+ Trains a 4-class PI-RADS classifier (grades 2–5, mapped to labels 0–3).
+
+ ```bash
+ python run_pirads.py --mode train --config config/config_pirads_train.yaml
+ ```
+
+ **Training details:**
+
+ | Component | Value |
+ |-----------|-------|
+ | Loss | CrossEntropy + cosine-similarity attention loss |
+ | Attention loss weight | Linear warmup over 25 epochs to `lambda=2.0` |
+ | Optimizer | AdamW (base LR `3e-5`, transformer LR `6e-5`) |
+ | Scheduler | CosineAnnealingLR |
+ | Metric | Quadratic Weighted Kappa (QWK) |
+ | Early stopping | After 40 epochs without validation loss improvement |
+ | AMP | Disabled by default (enabled in example YAML config) |
+
+ **Attention loss**: For each batch, the model's learned attention weights are compared against heatmap-derived attention labels via cosine similarity. PI-RADS 2 samples receive uniform attention (no lesion expected). The loss is weighted by `lambda_att`, which warms up linearly over the first 25 epochs.
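A sketch of this loss and its warmup, following the description above; the function names, shapes, and exact formulation (1 βˆ’ cosine similarity) are assumptions, not the project's verbatim implementation:

```python
import torch
import torch.nn.functional as F

def attention_loss(att_weights, att_labels):
    """1 - cosine similarity between predicted and heatmap-derived attention (assumed form)."""
    cos = F.cosine_similarity(att_weights, att_labels, dim=-1)  # (batch,)
    return (1.0 - cos).mean()

def lambda_att(epoch, warmup_epochs=25, lambda_max=2.0):
    """Linear warmup of the attention-loss weight over the first epochs."""
    return lambda_max * min(epoch / warmup_epochs, 1.0)

att_weights = torch.softmax(torch.randn(4, 32), dim=-1)  # model attention over 32 patches
att_labels = torch.softmax(torch.randn(4, 32), dim=-1)   # heatmap-derived targets
ce_loss = torch.tensor(1.2)                              # stand-in CrossEntropy term

total = ce_loss + lambda_att(epoch=10) * attention_loss(att_weights, att_labels)
```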
+
+ ## Stage 2: csPCa Risk Prediction
+
+ Builds on a frozen PI-RADS backbone to predict binary csPCa risk.
+
+ ```bash
+ python run_cspca.py --mode train --config config/config_cspca_train.yaml
+ ```
+
+ **Training details:**
+
+ | Component | Value |
+ |-----------|-------|
+ | Loss | Binary Cross-Entropy (BCE) |
+ | Backbone | Frozen PI-RADS model (ResNet18 + Transformer); attention module is trainable |
+ | Head | SimpleNN: `512 β†’ 256 β†’ 128 β†’ 1` with ReLU + Dropout(0.3) + Sigmoid |
+ | Optimizer | AdamW (LR `2e-4`) |
+ | Seeds | 20 random seeds (default) for 95% CI |
+ | Metrics | AUC, Sensitivity, Specificity |
+
+ The backbone's feature extractor (`net`), transformer, and `myfc` are frozen. The attention module and `SimpleNN` classification head are trained. After training across all seeds, the framework reports mean and 95% confidence intervals for AUC, sensitivity, and specificity.
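A sketch of a `512 β†’ 256 β†’ 128 β†’ 1` head with ReLU, Dropout(0.3), and Sigmoid, as described in the table above; the exact layer ordering and class layout are assumptions, not the repository's `SimpleNN` verbatim:

```python
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    """Hypothetical reconstruction of the csPCa classification head."""

    def __init__(self, in_features: int = 512):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(in_features, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)

head = SimpleNN().eval()
risk = head(torch.randn(2, 512))  # (2, 1) probabilities in [0, 1]
```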
mkdocs.yml ADDED
@@ -0,0 +1,67 @@
+ site_name: WSAttention-Prostate
+ site_description: Weakly-supervised attention-based 3D MIL for prostate cancer risk prediction on multiparametric MRI
+ repo_url: https://github.com/ai-assisted-healthcare/WSAttention-Prostate
+ repo_name: WSAttention-Prostate
+
+ theme:
+   name: material
+   logo: assets/logo.svg
+   favicon: assets/logo.svg
+   palette:
+     - scheme: default
+       primary: indigo
+       accent: teal
+       toggle:
+         icon: material/brightness-7
+         name: Switch to dark mode
+     - scheme: slate
+       primary: indigo
+       accent: teal
+       toggle:
+         icon: material/brightness-4
+         name: Switch to light mode
+   font:
+     text: Inter
+     code: JetBrains Mono
+   features:
+     - navigation.instant
+     - navigation.sections
+     - navigation.top
+     - search.suggest
+     - search.highlight
+     - content.code.copy
+     - content.tabs.link
+
+ plugins:
+   - search
+   # TODO: add mkdocstrings[python] plugin for autodoc support
+
+ markdown_extensions:
+   - admonition
+   - pymdownx.details
+   - pymdownx.superfences:
+       custom_fences:
+         - name: mermaid
+           class: mermaid
+           format: !!python/name:pymdownx.superfences.fence_code_format
+   - pymdownx.highlight:
+       anchor_linenums: true
+   - pymdownx.inlinehilite
+   - pymdownx.tabbed:
+       alternate_style: true
+   - tables
+   - attr_list
+   - md_in_html
+
+ nav:
+   - Home: index.md
+   - Getting Started: getting-started.md
+   - Pipeline: pipeline.md
+   - Architecture: architecture.md
+   - Configuration: configuration.md
+   - Inference: inference.md
+   - API Reference:
+     - Models: api/models.md
+     - Preprocessing: api/preprocessing.md
+     - Data Loading: api/data.md
+   - Contributing: contributing.md