---
title: Prostate Inference
emoji: 🚀
colorFrom: blue
colorTo: red
sdk: docker
pinned: false
short_description: Predicts csPCa risk and PI-RADS score from bpMRI sequences
app_port: 8501
---
<p align="center">
  <img src="docs/assets/logo.svg" alt="WSAttention-Prostate Logo" width="560">
</p>

<p align="center">
  <a href="https://huggingface.co/spaces/anirudh0410/Prostate-Inference">
    <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue" alt="Hugging Face Spaces">
  </a>
<a href="https://github.com/anirudhbalaraman/WSAttention-Prostate/actions/workflows/ci.yaml">
    <img src="https://github.com/anirudhbalaraman/WSAttention-Prostate/actions/workflows/ci.yaml/badge.svg" alt="CI/CD Status">
  </a>
  <img src="https://img.shields.io/badge/python-3.9-blue?logo=python&logoColor=white" alt="Python 3.9">
  <img src="https://img.shields.io/badge/pytorch-2.5-ee4c2c?logo=pytorch&logoColor=white" alt="PyTorch 2.5">  
  <img src="https://img.shields.io/badge/docker-automated-blue?logo=docker&logoColor=white" alt="Docker">
  <img src="https://img.shields.io/badge/license-Apache%202.0-green" alt="License">
</p>

# Weakly Supervised Attention-Based Deep Learning for Prostate Cancer Characterization from Bi-Parametric Prostate MRI
Predicts the PI-RADS score and the risk of clinically significant prostate cancer (csPCa) from the T2-Weighted (T2W), Diffusion-Weighted Imaging (DWI) and Apparent Diffusion Coefficient (ADC) sequences of bi-parametric MRI (bpMRI).

## 🚀 Platform Access
Real-time inference is available via the [GUI](https://huggingface.co/spaces/anirudh0410/Prostate-Inference).

## ⭐ Abstract

Deep learning methods used in medical AI—particularly for csPCa prediction and PI-RADS classification—typically rely on expert-annotated labels for training, which limits scalability to larger datasets and broader clinical adoption. To address this, we employ a two-stage multiple-instance learning (MIL) framework pretrained on scan-level PI-RADS annotations with attention-based weak supervision, guided by weak attention heatmaps automatically derived from ADC and DWI sequences. For downstream risk assessment, the PI-RADS classification head is replaced and fine-tuned on a substantially smaller dataset to predict csPCa risk. Careful preprocessing is applied to mitigate variability arising from cross-site MRI acquisition differences. For further details, please refer to our paper or visit the project website.

## Key Features

- ⚡ **Automatic Attention Heatmaps** — Weak attention heatmaps generated automatically from the DWI and ADC sequences.
- 🧠 **Weakly-Supervised Attention** — Heatmap-guided patch sampling and a cosine-similarity attention loss replace the need for voxel-level labels.
- 🧩 **3D Multiple Instance Learning** — Extracts volumetric patches from bpMRI scans and aggregates them via a transformer + attention pooling.
- 👁️ **Explainable** — Visualise the salient patches highlighting probable tumour regions.
- 🧹 **Preprocessing** — Preprocessing to minimize inter-center MRI acquisition variability.
- 🏥 **End-to-end Pipeline** — An open-source, clinically viable, complete pipeline.


## 🚀 Quick Start
### 1. Clone and Setup
```bash
git clone https://github.com/anirudhbalaraman/WSAttention-Prostate.git
cd WSAttention-Prostate
pip install -r requirements.txt
pytest tests/
```
### 2. Model Download

```bash
mkdir -p ./models
curl -L -o models/cspca_model.pth https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/cspca_model.pth
curl -L -o models/pirads.pt https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/pirads.pt
curl -L -o models/prostate_segmentation_model.pt https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/prostate_segmentation_model.pt
```
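If you prefer scripting the download, the same three checkpoints can be fetched with the Python standard library. This is a minimal sketch, not part of the repository; the URL pattern and checkpoint filenames are taken from the Hub links above.

```python
import os
import urllib.request

BASE_URL = "https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main"
# Checkpoint filenames as published on the Hugging Face Hub (see the URLs above).
CHECKPOINTS = ["cspca_model.pth", "pirads.pt", "prostate_segmentation_model.pt"]

def checkpoint_urls(base=BASE_URL, names=CHECKPOINTS):
    """Return (url, local_path) pairs for every checkpoint."""
    return [(f"{base}/{name}", os.path.join("models", name)) for name in names]

def download_all():
    """Download all checkpoints into ./models, creating the directory if needed."""
    os.makedirs("models", exist_ok=True)
    for url, path in checkpoint_urls():
        urllib.request.urlretrieve(url, path)  # follows redirects, like curl -L
```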

## 🚀 Usage
### 🩺 Inference
```bash
python run_inference.py --config config/config_preprocess.yaml
```

Run `run_inference.py` to execute the full pipeline, from preprocessing to model predictions.
- 📂 **Input arguments:**
  -  *t2_dir, dwi_dir, adc_dir*: Paths to the T2W, DWI and ADC sequences respectively.
  -  *output_dir*: Path to store preprocessed files and results.
 
  ⚠️ ***NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.***

- 📊 **Outputs:**
  The following are stored for each scan:
  -  Risk of csPCa.
  -  PI-RADS score.
  -  Coordinates of the top 5 salient patches.

The results are stored in `results.json` in `output_dir`, along with the intermediary preprocessing files, including the prostate segmentation mask. The patches can be visualised using `visualisation.ipynb`.
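As an illustration, `results.json` can be summarised with a few lines of Python. The key names below (`risk`, `pirads`, `salient_patches`) are assumptions for this sketch; check the file produced by your run for the actual schema.

```python
import json

def summarise_results(path="results.json"):
    """Return (scan_id, csPCa risk, PI-RADS score, patch coordinates) per scan.

    NOTE: the key names "risk", "pirads" and "salient_patches" are
    illustrative assumptions; adapt them to the schema your run produces.
    """
    with open(path) as f:
        results = json.load(f)
    return [
        (scan_id, entry["risk"], entry["pirads"], entry["salient_patches"])
        for scan_id, entry in results.items()
    ]
```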


### 🧹 Preprocessing

Execute `preprocess_main.py` to preprocess your MRI files.

⚠️ ***NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.***
```bash
python preprocess_main.py \
  --steps register_and_crop get_segmentation_mask histogram_match get_heatmap \
  --config config/config_preprocess.yaml
```
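The shared-filename constraint can be verified up front before launching the pipeline. This is an assumption-level helper sketch, not a script shipped with the repository:

```python
import os

def check_sequence_filenames(t2_dir, dwi_dir, adc_dir):
    """Verify every scan has identically named .nrrd files in all three dirs.

    Returns the set of common filenames; raises ValueError on any mismatch.
    """
    def nrrd_names(d):
        return {f for f in os.listdir(d) if f.lower().endswith(".nrrd")}

    t2, dwi, adc = nrrd_names(t2_dir), nrrd_names(dwi_dir), nrrd_names(adc_dir)
    if not (t2 == dwi == adc):
        mismatched = (t2 | dwi | adc) - (t2 & dwi & adc)
        raise ValueError(f"Sequences missing for: {sorted(mismatched)}")
    return t2
```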


### ⚙️ PI-RADS and csPCa Model Training
- **Input Arguments:**
  - *dataset_json*: File paths of the scans. JSON used for training: `dataset/PI-RADS_data.json`.
  - *data_root*: Root directory of the T2W files.
  - *tile_count*: Number of patches per scan.
  - *tile_size*: Length and width of each patch.
  - *depth*: Depth of each 3D patch.

⚠️ ***NOTE: `run_cspca.py` in train mode requires the pretrained PI-RADS MIL backbone.***

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml
python run_cspca.py --mode train --config config/config_cspca_train.yaml
```
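For orientation, a training config might combine the arguments above roughly as follows. The field names come from the argument list; the values are placeholders, so check `config/config_pirads_train.yaml` for the authoritative schema:

```yaml
# Illustrative sketch only; see config/config_pirads_train.yaml for the real schema.
dataset_json: dataset/PI-RADS_data.json   # file paths of the scans
data_root: /data/prostate/t2w             # root directory of T2W files (placeholder)
tile_count: 44                            # patches sampled per scan (placeholder)
tile_size: 64                             # patch length/width in voxels (placeholder)
depth: 16                                 # patch depth (placeholder)
```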

### 📊 Testing

```bash
python run_pirads.py --mode test --config config/config_pirads_test.yaml
python run_cspca.py --mode test --config config/config_cspca_test.yaml
```

See the [full documentation](https://anirudhbalaraman.github.io/WSAttention-Prostate/) for detailed configuration options and data format requirements.

## Project Structure

```
WSAttention-Prostate/
├── run_pirads.py                # PI-RADS training/testing entry point
├── run_cspca.py                 # csPCa training/testing entry point
├── run_inference.py             # Full inference pipeline
├── preprocess_main.py           # Preprocessing entry point
├── config/                      # YAML configuration files
├── src/
│   ├── model/
│   │   ├── MIL.py               # MILModel_3D — core MIL architecture, PI-RADS model
│   │   └── csPCa_model.py       # csPCa_Model
│   ├── data/
│   │   ├── data_loader.py       # MONAI data pipeline
│   │   └── custom_transforms.py # Custom MONAI transforms
│   ├── train/
│   │   ├── train_pirads.py      # PI-RADS training loop
│   │   └── train_cspca.py       # csPCa training loop
│   ├── preprocessing/           # Registration, segmentation, histogram matching, heatmaps
│   └── utils.py                 # Shared utilities
├── tests/
├── dataset/                     # Reference images for histogram matching
└── models/                      # Downloaded checkpoints (not in repo)
```

## 🙏 Acknowledgement
This work was in large part funded by the Wilhelm Sander Foundation and by the European Union. Views and opinions expressed are, however, those of the author(s) only and do not necessarily reflect those of the European Union or the European Health and Digital Executive Agency (HADEA). Neither the European Union nor the granting authority can be held responsible for them.