Create README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,72 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: en
|
| 3 |
+
license: mit
|
| 4 |
+
library_name: pytorch
|
| 5 |
+
tags:
|
| 6 |
+
- medical-imaging
|
| 7 |
+
- mri
|
| 8 |
+
- 3d-resnet
|
| 9 |
+
- monai
|
| 10 |
+
- oncology
|
| 11 |
+
- breast-cancer
|
| 12 |
+
- sustainable-ai
|
| 13 |
+
- explainable-ai
|
| 14 |
+
metrics:
|
| 15 |
+
- roc_auc
|
| 16 |
+
pipeline_tag: image-classification
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Advanced Multi-Stream 3D-ResNet for ODELIA (Phase 3)
|
| 20 |
+
|
| 21 |
+
This repository contains the "Better Model" (Phase 3) for breast cancer classification on multi-parametric MRI. This version represents the most stable and optimized iteration, featuring a hybrid architecture designed to handle high-resolution 3D sequences with extreme memory efficiency.
|
| 22 |
+
|
| 23 |
+
## 🛠 Key Engineering Breakthroughs (Phase 3)
|
| 24 |
+
- **Instance Normalization Integration:** Replaced Batch Normalization with **InstanceNorm 3D**. This critical fix allows the model to maintain stable gradients even with a Batch Size of 1, preventing the collapse observed in earlier versions.
|
| 25 |
+
- **Dynamic Temporal Fusion:** Uses a Multi-Stream approach where the Kinetic stream processes a variable number of post-contrast phases.
|
| 26 |
+
- **Improved Temporal Aggregation:** Transitioned to **Temporal Mean Pooling** for feature fusion, providing a more robust "global temporal summary" compared to complex recurrent layers (LSTM/RNN) on this specific dataset volume.
|
| 27 |
+
- **Explainable AI (XAI):** Full support for **3D Grad-CAM** visualization on the T2 structural stream.
|
| 28 |
+
|
| 29 |
+
## Model Description
|
| 30 |
+
- **Developed by:** [Your Name/Pseudo]
|
| 31 |
+
- **Architecture:** Multi-Stream 3D ResNet-18 (Custom InstanceNorm version).
|
| 32 |
+
- **Task:** 3 Multi-class classification (Normal, Benign, Malignant).
|
| 33 |
+
- **Input Modalities:** 1. **T2 Stream:** Structural 3D MRI volume.
|
| 34 |
+
2. **Kinetics Stream:** Temporal sequence of subtraction volumes (Post_i - Pre).
|
| 35 |
+
- **Spatial Resolution:** 128 x 128 x 64.
|
| 36 |
+
|
| 37 |
+
## Technical Details
|
| 38 |
+
|
| 39 |
+
### Architecture Components
|
| 40 |
+
- **Encoders:** Two parallel 3D ResNet-18 backbones with `InstanceNorm3D`.
|
| 41 |
+
- **Temporal Handling:** The kinetics encoder processes each time point independently; features are then aggregated via mean pooling across the time dimension.
|
| 42 |
+
- **Classification Head:** A 2-layer MLP with **40% Dropout** to prevent overfitting on complex features.
|
| 43 |
+
|
| 44 |
+
### Training Protocol
|
| 45 |
+
- **Optimizer:** AdamW (LR = 2e-4, Weight Decay = 1e-4).
|
| 46 |
+
- **Scheduler:** CosineAnnealingLR (T_max=50) for smooth convergence.
|
| 47 |
+
- **Loss:** **Focal Loss (gamma=2.0)** with dynamic class weighting to prioritize the rare "Benign" and "Malignant" cases.
|
| 48 |
+
- **Optimization:** Mixed Precision (AMP) and Gradient Accumulation (steps=4).
|
| 49 |
+
|
| 50 |
+
## Performance & Sustainability
|
| 51 |
+
The model was trained on the **IDUN HPC Cluster**.
|
| 52 |
+
- **Carbon Tracking:** Training included environmental footprint monitoring (~0.25 kW/GPU on a low-carbon Norwegian grid).
|
| 53 |
+
- **Stability:** Evaluation shows consistent convergence across 5 folds, validated by a rigorous audit of the patient split.
|
| 54 |
+
|
| 55 |
+
## How to Load the Model
|
| 56 |
+
This model requires the `MultiStreamResNet` class from the `utils.py` file.
|
| 57 |
+
|
| 58 |
+
```python
|
| 59 |
+
import torch
|
| 60 |
+
from utils import MultiStreamResNet
|
| 61 |
+
|
| 62 |
+
# Initialize model with Phase 3 dimensions
|
| 63 |
+
model = MultiStreamResNet(num_classes=3, hidden_dim=256)
|
| 64 |
+
|
| 65 |
+
# Load weights for a specific fold
|
| 66 |
+
model.load_state_dict(torch.load("best_resnet_odelia_fold0.pth", map_location="cpu"))
|
| 67 |
+
model.eval()
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Explainability (Grad-CAM)
|
| 71 |
+
The model includes a wrapper for Explainable AI. It targets the layer4 of the T2 encoder to generate 3D heatmaps,
|
| 72 |
+
helping to identify the specific anatomical regions influencing the "Malignant" prediction.
|