---
title: Prostate Inference
emoji: 🚀
colorFrom: blue
colorTo: red
sdk: docker
pinned: false
short_description: Predicts csPCa risk and PI-RADS score from bpMRI sequences
app_port: 8501
---

WSAttention-Prostate

Hugging Face Spaces CI/CD Status Python 3.9 PyTorch 2.5 Docker License

Weakly Supervised Attention-Based Deep Learning for Prostate Cancer Characterization from Bi-Parametric Prostate MRI.

Predicts PI-RADS score and risk of clinically significant prostate cancer (csPCa) from T2-Weighted (T2W), Diffusion Weighted Imaging (DWI) and Apparent Diffusion Coefficient (ADC) sequences of bi-parametric MRI (bpMRI).

πŸš€ Platform Access

Real-time inference via GUI

⭐ Abstract

Deep learning methods used in medical AIβ€”particularly for csPCa prediction and PI-RADS classificationβ€”typically rely on expert-annotated labels for training, which limits scalability to larger datasets and broader clinical adoption. To address this, we employ a two-stage multiple-instance learning (MIL) framework pretrained on scan-level PI-RADS annotations with attention-based weak supervision, guided by weak attention heatmaps automatically derived from ADC and DWI sequences. For downstream risk assessment, the PI-RADS classification head is replaced and fine-tuned on a substantially smaller dataset to predict csPCa risk. Careful preprocessing is applied to mitigate variability arising from cross-site MRI acquisition differences. For further details, please refer to our paper or visit the project website.
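The second stage described above swaps the PI-RADS classification head for a csPCa risk head before fine-tuning on the smaller dataset. A minimal PyTorch sketch of that head-replacement pattern follows; the module and attribute names are illustrative placeholders, not the repository's actual API.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the two-stage transfer: a backbone pretrained on
# scan-level PI-RADS labels has its classification head replaced by a
# single-logit csPCa risk head, and only the new head is fine-tuned.
class MILBackbone(nn.Module):
    def __init__(self, feat_dim=128, n_pirads_classes=5):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(32, feat_dim), nn.ReLU())
        self.head = nn.Linear(feat_dim, n_pirads_classes)  # PI-RADS head

    def forward(self, x):
        return self.head(self.encoder(x))

model = MILBackbone()
# Stage 2: replace the PI-RADS head with a single-logit csPCa risk head
model.head = nn.Linear(128, 1)
# Freeze the pretrained encoder so only the new head is fine-tuned
for p in model.encoder.parameters():
    p.requires_grad = False

risk = torch.sigmoid(model(torch.randn(4, 32)))  # per-scan csPCa risk in [0, 1]
print(risk.shape)
```

The same idea applies regardless of backbone: keep the pretrained feature extractor, retrain only the small task-specific head when labelled data is scarce.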

Key Features

  • ⚡ Automatic Attention Heatmaps — Weak attention heatmaps generated automatically from DWI and ADC sequences.
  • 🧠 Weakly-Supervised Attention — Heatmap-guided patch sampling and a cosine-similarity attention loss replace the need for voxel-level labels.
  • 🧩 3D Multiple Instance Learning — Extracts volumetric patches from bpMRI scans and aggregates them via a transformer with attention pooling.
  • 👁️ Explainable — Visualise salient patches highlighting probable tumour regions.
  • 🧹 Preprocessing — Preprocessing to minimize inter-center MRI acquisition variability.
  • 🏥 End-to-end Pipeline — Open-source, clinically viable, complete pipeline.

πŸš€ Quick Start

1. Clone and Setup

git clone https://github.com/anirudhbalaraman/WSAttention-Prostate.git
cd WSAttention-Prostate
pip install -r requirements.txt
pytest tests/

2. Model Download

mkdir -p ./models
curl -L -o models/cspca_model.pth https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/cspca_model.pth
curl -L -o models/pirads.pt https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/pirads.pt
curl -L -o models/prostate_segmentation_model.pt https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/prostate_segmentation_model.pt

πŸš€ Usage

🩺 Inference

python run_inference.py --config config/config_preprocess.yaml

Run run_inference.py to execute the full pipeline, from preprocessing to model predictions.

  • πŸ“‚ Input arguments:

    • t2_dir, dwi_dir, adc_dir: Paths to the T2W, DWI and ADC sequences respectively.
    • output_dir: Path to store preprocessed files and results.

    ⚠️ NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.
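    These paths are supplied through the YAML config. A sketch of what the relevant part of config/config_preprocess.yaml might look like follows; key names are taken from the argument list above, values are placeholders, and the actual nesting may differ:

    ```yaml
    # Illustrative only — check config/config_preprocess.yaml for the real layout
    t2_dir: /data/scans/t2w      # T2W NRRD files
    dwi_dir: /data/scans/dwi     # DWI NRRD files
    adc_dir: /data/scans/adc     # ADC NRRD files
    output_dir: /data/output     # preprocessed files and results.json
    ```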

  • πŸ“Š Outputs: The following are stored for each scan:

    • Risk of csPCa.
    • PI-RADS score.
    • Coordinates of the top 5 salient patches.

    The results are stored in results.json in output_dir, along with the intermediary files from preprocessing, including the prostate segmentation mask. The patches can be visualised using visualisation.ipynb.
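A downstream script can consume results.json with the standard library. The exact schema is defined by run_inference.py; the field names below (cspca_risk, pirads, salient_patches) are hypothetical placeholders used only to illustrate the pattern.

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical results.json entry — field names are illustrative, not the
# repository's actual keys; consult run_inference.py for the real schema.
sample = {
    "scan_001": {
        "cspca_risk": 0.73,
        "pirads": 4,
        "salient_patches": [[32, 48, 10], [40, 52, 12]],  # patch coordinates
    }
}

with TemporaryDirectory() as output_dir:
    path = Path(output_dir) / "results.json"
    path.write_text(json.dumps(sample))

    # Load and report per-scan predictions
    results = json.loads(path.read_text())
    for scan, r in results.items():
        print(f"{scan}: risk={r['cspca_risk']:.2f}, PI-RADS={r['pirads']}, "
              f"top patch={r['salient_patches'][0]}")
```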

🧹 Preprocessing

Execute preprocess_main.py to preprocess your MRI files. ⚠️ NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.

python preprocess_main.py \
  --steps register_and_crop get_segmentation_mask histogram_match get_heatmap \
  --config config/config_preprocess.yaml

βš™οΈ PI-RADS ans csPCa Model Training

  • Input Arguments:
    • dataset_json: File paths of the scans. JSON used for training: dataset/PI-RADS_data.json.
    • data_root: Root directory of T2W files.
    • tile_count: Number of patches per scan.
    • tile_size: Length and width of each patch.
    • depth: Depth of each 3D patch.

⚠️ NOTE: run_cspca.py in train mode requires the pretrained PI-RADS MIL backbone, so train (or download) the PI-RADS model first.

python run_pirads.py --mode train --config config/config_pirads_train.yaml
python run_cspca.py --mode train --config config/config_cspca_train.yaml
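The training arguments above live in the YAML configs passed via --config. A hypothetical fragment of config/config_pirads_train.yaml, with placeholder values (the shipped configs define the real key names, nesting and defaults):

```yaml
# Illustrative values only — consult the shipped configs for real defaults
dataset_json: dataset/PI-RADS_data.json
data_root: /data/preprocessed/t2w
tile_count: 32     # patches sampled per scan
tile_size: 64      # in-plane patch size (voxels)
depth: 16          # patch depth (slices)
```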

πŸ“Š Testing

python run_pirads.py --mode test --config config/config_pirads_test.yaml
python run_cspca.py --mode test --config config/config_cspca_test.yaml

See the full documentation for detailed configuration options and data format requirements.

Project Structure

WSAttention-Prostate/
β”œβ”€β”€ run_pirads.py                # PI-RADS training/testing entry point
β”œβ”€β”€ run_cspca.py                 # csPCa training/testing entry point
β”œβ”€β”€ run_inference.py             # Full inference pipeline
β”œβ”€β”€ preprocess_main.py           # Preprocessing entry point
β”œβ”€β”€ config/                      # YAML configuration files
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ model/
β”‚   β”‚   β”œβ”€β”€ MIL.py               # MILModel_3D β€” core MIL architecture, PI-RADS model
β”‚   β”‚   └── csPCa_model.py       # csPCa_Model
β”‚   β”œβ”€β”€ data/
β”‚   β”‚   β”œβ”€β”€ data_loader.py       # MONAI data pipeline
β”‚   β”‚   └── custom_transforms.py # Custom MONAI transforms 
β”‚   β”œβ”€β”€ train/
β”‚   β”‚   β”œβ”€β”€ train_pirads.py      # PI-RADS training loop
β”‚   β”‚   └── train_cspca.py       # csPCa training loop
β”‚   β”œβ”€β”€ preprocessing/           # Registration, segmentation, histogram matching, heatmaps
β”‚   └── utils.py                 # Shared utilities
β”œβ”€β”€ tests/
β”œβ”€β”€ dataset/                     # Reference images for histogram matching
└── models/                      # Downloaded checkpoints (not in repo)

πŸ™ Acknowledgement

This work was funded in large part by the Wilhelm Sander Foundation. Funded by the European Union. Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union or the European Health and Digital Executive Agency (HADEA). Neither the European Union nor the granting authority can be held responsible for them.