---
title: Prostate Inference
emoji: πŸš€
colorFrom: blue
colorTo: red
sdk: docker
pinned: false
short_description: Predicts csPCa risk and PI-RADS score from bpMRI sequences
app_port: 8501
---

WSAttention-Prostate Logo

Hugging Face Spaces CI/CD Status Python 3.9 PyTorch 2.5 Docker License

# Weakly Supervised Attention-Based Deep Learning for Prostate Cancer Characterization from Bi-Parametric Prostate MRI

Predicts PI-RADS score and risk of clinically significant prostate cancer (csPCa) from T2-Weighted (T2W), Diffusion-Weighted Imaging (DWI) and Apparent Diffusion Coefficient (ADC) sequences of bi-parametric MRI (bpMRI).

## πŸš€ Platform Access

Real-time inference via [GUI](https://huggingface.co/spaces/anirudh0410/Prostate-Inference)

## ⭐ Abstract

Deep learning methods used in medical AIβ€”particularly for csPCa prediction and PI-RADS classificationβ€”typically rely on expert-annotated labels for training, which limits scalability to larger datasets and broader clinical adoption. To address this, we employ a two-stage multiple-instance learning (MIL) framework pretrained on scan-level PI-RADS annotations with attention-based weak supervision, guided by weak attention heatmaps automatically derived from ADC and DWI sequences. For downstream risk assessment, the PI-RADS classification head is replaced and fine-tuned on a substantially smaller dataset to predict csPCa risk. Careful preprocessing is applied to mitigate variability arising from cross-site MRI acquisition differences.

For further details, please refer to our paper or visit the project website.

## Key Features

- ⚑ **Automatic Attention Heatmaps** β€” Weak attention heatmaps generated automatically from DWI and ADC sequences.
- 🧠 **Weakly-Supervised Attention** β€” Heatmap-guided patch sampling and a cosine-similarity attention loss replace the need for voxel-level labels.
- 🧩 **3D Multiple Instance Learning** β€” Extracts volumetric patches from bpMRI scans and aggregates them via a transformer + attention pooling.
- πŸ‘οΈ **Explainable** β€” Visualise salient patches highlighting probable tumour regions.
- 🧹 **Preprocessing** β€” Preprocessing to minimize inter-center MRI acquisition variability.
- πŸ₯ **End-to-end Pipeline** β€” Open-source, clinically viable complete pipeline.
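The attention-pooling step at the heart of the MIL aggregation can be illustrated with a small sketch. This is not the repository's implementation (the real model lives in `src/model/MIL.py`); it is a minimal NumPy version of a common tanh-attention formulation, where `W` and `w` stand in for learned parameters:

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_pool(embeddings, W, w):
    """Aggregate per-patch embeddings into one scan-level vector.

    embeddings: (N, D) array, one row per 3D patch
    W: (D, H) and w: (H,) -- stand-ins for learned attention parameters
    Returns the pooled (D,) embedding and the (N,) attention weights.
    """
    scores = np.tanh(embeddings @ W) @ w   # (N,) unnormalised patch scores
    alpha = softmax(scores)                # weights sum to 1 across patches
    pooled = alpha @ embeddings            # attention-weighted average
    return pooled, alpha

# Toy example: 12 patches with 16-dim features, 8 hidden attention units.
rng = np.random.default_rng(0)
E = rng.normal(size=(12, 16))
pooled, alpha = attention_pool(E, rng.normal(size=(16, 8)), rng.normal(size=8))
```

The patches with the largest `alpha` values are the ones a model like this would report as most salient, which is the idea behind the top-5 salient patches returned at inference time.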
## πŸš€ Quick Start

### 1. Clone and Setup

```bash
git clone https://github.com/anirudhbalaraman/WSAttention-Prostate.git
cd WSAttention-Prostate
pip install -r requirements.txt
pytest tests/
```

### 2. Model Download

```bash
mkdir -p ./models
curl -L -o models/file1.pth https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/cspca_model.pth
curl -L -o models/file2.pth https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/pirads.pt
curl -L -o models/file3.pth https://huggingface.co/anirudh0410/WSAttention-Prostate/resolve/main/prostate_segmentation_model.pt
```

## πŸš€ Usage

### 🩺 Inference

```bash
python run_inference.py --config config/config_preprocess.yaml
```

Run `run_inference.py` to execute the full pipeline, from preprocessing to model predictions.

- πŸ“‚ **Input arguments:**
  - *t2_dir, dwi_dir, adc_dir*: Paths to the T2W, DWI and ADC sequences respectively.
  - *output_dir*: Path to store preprocessed files and results.

⚠️ ***NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.***

- πŸ“Š **Outputs:** The following are stored for each scan:
  - Risk of csPCa.
  - PI-RADS score.
  - Coordinates of the top 5 salient patches.

The results are stored in `results.json`, saved in *output_dir* along with the intermediary files from preprocessing, including the prostate segmentation mask. The patches can be visualised using `visualisation.ipynb`.

### 🧹 Preprocessing

Execute `preprocess_main.py` to preprocess your MRI files.

⚠️ ***NOTE: For each scan, all sequences should share the same filename, and the input files must be in NRRD format.***

```bash
python preprocess_main.py \
    --steps register_and_crop get_segmentation_mask histogram_match get_heatmap \
    --config config/config_preprocess.yaml
```

### βš™οΈ PI-RADS and csPCa Model Training

- **Input Arguments:**
  - *dataset_json*: File paths of the scans. JSON used for training: `dataset/PI-RADS_data.json`.
  - *data_root*: Root directory of T2W files.
  - *tile_count*: Number of patches per scan.
  - *tile_size*: Length and width of each patch.
  - *depth*: Depth of each 3D patch.

⚠️ ***NOTE: run_cspca.py in train mode requires the PI-RADS MIL backbone.***

```bash
python run_pirads.py --mode train --config config/config_pirads_train.yaml
python run_cspca.py --mode train --config config/config_cspca_train.yaml
```

### πŸ“Š Testing

```bash
python run_pirads.py --mode test --config config/config_pirads_test.yaml
python run_cspca.py --mode test --config config/config_cspca_test.yaml
```

See the [full documentation](https://anirudhbalaraman.github.io/WSAttention-Prostate/) for detailed configuration options and data format requirements.

## Project Structure

```
WSAttention-Prostate/
β”œβ”€β”€ run_pirads.py            # PI-RADS training/testing entry point
β”œβ”€β”€ run_cspca.py             # csPCa training/testing entry point
β”œβ”€β”€ run_inference.py         # Full inference pipeline
β”œβ”€β”€ preprocess_main.py       # Preprocessing entry point
β”œβ”€β”€ config/                  # YAML configuration files
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ model/
β”‚   β”‚   β”œβ”€β”€ MIL.py               # MILModel_3D β€” core MIL architecture, PI-RADS model
β”‚   β”‚   └── csPCa_model.py       # csPCa_Model
β”‚   β”œβ”€β”€ data/
β”‚   β”‚   β”œβ”€β”€ data_loader.py       # MONAI data pipeline
β”‚   β”‚   └── custom_transforms.py # Custom MONAI transforms
β”‚   β”œβ”€β”€ train/
β”‚   β”‚   β”œβ”€β”€ train_pirads.py      # PI-RADS training loop
β”‚   β”‚   └── train_cspca.py       # csPCa training loop
β”‚   β”œβ”€β”€ preprocessing/       # Registration, segmentation, histogram matching, heatmaps
β”‚   └── utils.py             # Shared utilities
β”œβ”€β”€ tests/
β”œβ”€β”€ dataset/                 # Reference images for histogram matching
└── models/                  # Downloaded checkpoints (not in repo)
```

## πŸ™ Acknowledgement

This work was in large parts funded by the Wilhelm Sander Foundation. Funded by the European Union. Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union or the European Health and Digital Executive Agency (HADEA).
Neither the European Union nor the granting authority can be held responsible for them.
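The per-scan outputs described under Usage land in `results.json`, which can be collected with a short script for downstream analysis. Note that the key names used here (`cspca_risk`, `pirads`, `top_patches`) are assumptions for illustration only; inspect the file produced by `run_inference.py` for the actual schema:

```python
import json
from pathlib import Path

# Hypothetical results.json content -- the key names are assumptions,
# not the documented schema of run_inference.py.
sample = {
    "scan_001": {
        "cspca_risk": 0.73,
        "pirads": 4,
        "top_patches": [[32, 48, 10], [36, 50, 12]],
    }
}
path = Path("results.json")
path.write_text(json.dumps(sample))

# Load and summarise every scan in the file.
results = json.loads(path.read_text())
for scan_id, r in results.items():
    print(f"{scan_id}: csPCa risk={r['cspca_risk']:.2f}, "
          f"PI-RADS={r['pirads']}, top patch={r['top_patches'][0]}")
```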