---
language: en
license: mit
library_name: pytorch
tags:
- medical-imaging
- mri
- 3d-resnet
- monai
- oncology
- breast-cancer
- sustainable-ai
- explainable-ai
metrics:
- roc_auc
pipeline_tag: image-classification
---

# Advanced Multi-Stream 3D-ResNet for ODELIA (Phase 3)

This repository contains the "Better Model" (Phase 3) for breast cancer classification on multi-parametric MRI. This version is the most stable and optimized iteration: a multi-stream 3D-ResNet that processes high-resolution 3D sequences within a tight memory budget (batch size of 1, mixed precision and gradient accumulation).
|
|
## Key Engineering Breakthroughs (Phase 3)
- **Instance Normalization Integration:** Replaced Batch Normalization with **InstanceNorm3d**. This critical fix keeps gradients stable even with a batch size of 1, preventing the collapse observed in earlier versions.
- **Dynamic Temporal Fusion:** Uses a multi-stream approach in which the Kinetics stream processes a variable number of post-contrast phases.
- **Improved Temporal Aggregation:** Chose temporal mean pooling over recurrent units (LSTM/RNN) to obtain a more robust global temporal summary and to reduce overfitting at this dataset size.
- **Explainable AI (XAI):** Full support for **3D Grad-CAM** visualization on the T2 structural stream.

## Mathematical & Architectural Refinement
- **From BatchNorm to InstanceNorm:** Early versions (V3) suffered a performance collapse (AUROC ~0.55) caused by combining a batch size of 1 with Batch Normalization layers: statistics estimated from a single volume are noisy and do not match the running statistics used at inference. Switching to **Instance Normalization**, which computes statistics per sample and per channel, resolved this because its behavior no longer depends on the batch size.
- **Robust Temporal Aggregation:** To counter the severe overfitting observed with LSTM units (~0.0003 training loss but poor validation performance), we implemented **Temporal Mean Pooling**. This non-learnable layer summarizes the enhancement kinetics without the parameter overhead of recurrent networks.
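
Both refinements can be checked numerically. The sketch below uses plain PyTorch layers on toy tensors (not the repository's encoders) to show that `nn.InstanceNorm3d` gives a sample the same output whether it is normalized alone or inside a batch, while `nn.BatchNorm3d` does not, and that temporal mean pooling is a parameter-free average over the time axis that accepts any number of phases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy feature maps: (batch, channels, depth, height, width).
x = torch.randn(4, 8, 6, 6, 6)

inorm = nn.InstanceNorm3d(8, affine=True)  # statistics per sample and per channel
bnorm = nn.BatchNorm3d(8)                  # statistics shared across the batch
inorm.train()
bnorm.train()

# InstanceNorm: sample 0 normalized alone == sample 0 normalized inside the batch.
inorm_alone = inorm(x[:1])
inorm_batch = inorm(x)[:1]

# BatchNorm: the output for sample 0 depends on the other samples in the batch.
bnorm_alone = bnorm(x[:1])
bnorm_batch = bnorm(x)[:1]

print(torch.allclose(inorm_alone, inorm_batch, atol=1e-5))  # True
print(torch.allclose(bnorm_alone, bnorm_batch, atol=1e-5))  # False

# Temporal mean pooling: per-phase feature vectors (batch, time, features) -> (batch, features).
phase_features = torch.randn(2, 5, 512)  # toy values: T = 5 phases, 512-d features
pooled = phase_features.mean(dim=1)      # no learnable parameters, works for any T
print(pooled.shape)  # torch.Size([2, 512])
```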
|
|
## Advanced Pre-processing & Augmentation
- **Automatic ROI Cropping:** Uses MONAI's `CropForegroundd` to crop each volume to its foreground bounding box, so the fixed input grid is spent on breast tissue rather than on background and irrelevant thoracic signal.
- **Modality Dropout (20%):** A custom transform that randomly masks the T2 volume or the Kinetics phases during training. This discourages co-adaptation between the streams and forces the model to learn useful features from each modality.
- **Focal Loss Strategy:** Implemented to address class imbalance, specifically targeting the "Benign" class detection failure observed in baseline models.
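
A minimal sketch of modality dropout follows. It assumes each training sample is a dictionary with the hypothetical keys `"t2"` and `"kinetics"` and that "masking" means zeroing one whole stream; the repository's custom transform may select modalities or mask individual phases differently. Only the 20% probability comes from the description above.

```python
import random

import torch


class ModalityDropout:
    """Hypothetical dictionary transform: with probability `p`, zero out one whole stream.

    Exactly one stream is dropped when the transform fires, so the model never
    receives an all-zero input.
    """

    def __init__(self, p: float = 0.2, keys=("t2", "kinetics")):
        self.p = p
        self.keys = keys

    def __call__(self, sample: dict) -> dict:
        if random.random() >= self.p:
            return sample  # keep both modalities (80% of the time with p=0.2)
        out = dict(sample)
        dropped = random.choice(self.keys)
        out[dropped] = torch.zeros_like(sample[dropped])
        return out


sample = {
    "t2": torch.randn(1, 16, 16, 8),           # toy (channel, H, W, D) volume
    "kinetics": torch.randn(3, 1, 16, 16, 8),  # toy (time, channel, H, W, D), T = 3
}
random.seed(0)
transform = ModalityDropout(p=0.2)
n_dropped = 0
for _ in range(1000):
    out = transform(sample)
    if any(torch.count_nonzero(out[k]) == 0 for k in transform.keys):
        n_dropped += 1
print(n_dropped)  # close to 200 with p = 0.2
```

Dropping a single stream at a time (rather than each stream independently) is a design choice of this sketch: it guarantees that at least one modality always carries signal.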
|
|
## Model Description
- **Developed by:** THOUAN Simon
- **Architecture:** Multi-Stream 3D ResNet-18 (custom InstanceNorm version).
- **Task:** 3-class classification (Normal, Benign, Malignant).
- **Input Modalities:**
  1. **T2 Stream:** Structural 3D MRI volume.
  2. **Kinetics Stream:** Temporal sequence of subtraction volumes (Post_i - Pre, i.e. each post-contrast phase minus the pre-contrast volume).
- **Spatial Size:** 128 x 128 x 64 voxels.
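
The Kinetics input can be illustrated with a short sketch. It assumes the pre- and post-contrast volumes are already co-registered and resampled to the same 128 x 128 x 64 grid; the function name `build_kinetics_stream` and the `(T, 1, H, W, D)` layout are hypothetical choices for this example.

```python
import torch


def build_kinetics_stream(pre: torch.Tensor, posts: list) -> torch.Tensor:
    """Stack subtraction volumes Post_i - Pre into a (T, 1, H, W, D) tensor.

    `pre` and every element of `posts` are (H, W, D) volumes on the same grid;
    T equals the number of post-contrast phases, which may vary per patient.
    """
    for post in posts:
        if post.shape != pre.shape:
            raise ValueError("all volumes must share the same spatial grid")
    subtractions = [post - pre for post in posts]
    return torch.stack(subtractions).unsqueeze(1)  # add the channel axis


pre = torch.rand(128, 128, 64)
posts = [pre + 0.1 * (i + 1) for i in range(4)]  # toy phases with rising enhancement
kinetics = build_kinetics_stream(pre, posts)
print(kinetics.shape)  # torch.Size([4, 1, 128, 128, 64])
```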

## Technical Details

### Architecture Components
- **Encoders:** Two parallel 3D ResNet-18 backbones with `InstanceNorm3d`.
- **Temporal Handling:** The kinetics encoder processes each time point independently; the resulting features are then aggregated by mean pooling across the time dimension.
- **Classification Head:** A 2-layer MLP with **40% dropout** to reduce overfitting.
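
To make the data flow concrete, here is a minimal sketch of the two-stream layout. It is not the repository's `MultiStreamResNet`: each ResNet-18 backbone is replaced by a tiny hypothetical encoder (`TinyEncoder3D`) so the example runs quickly, and concatenating the two feature vectors before the MLP is an assumption. It shows how the kinetics encoder can process every time point independently (by folding time into the batch axis), how mean pooling over time handles any number of phases, and where the 40% dropout sits in the head.

```python
import torch
import torch.nn as nn


class TinyEncoder3D(nn.Module):
    """Hypothetical stand-in for a 3D ResNet-18 backbone with InstanceNorm3d."""

    def __init__(self, out_dim: int = 64):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm3d(16, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, out_dim, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm3d(out_dim, affine=True),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),  # global pooling -> (N, out_dim, 1, 1, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x).flatten(1)  # (N, out_dim)


class MultiStreamSketch(nn.Module):
    def __init__(self, num_classes: int = 3, feat_dim: int = 64, hidden_dim: int = 256):
        super().__init__()
        self.t2_encoder = TinyEncoder3D(feat_dim)
        self.kin_encoder = TinyEncoder3D(feat_dim)
        self.head = nn.Sequential(  # 2-layer MLP with 40% dropout
            nn.Linear(2 * feat_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, t2: torch.Tensor, kinetics: torch.Tensor) -> torch.Tensor:
        # t2: (B, 1, H, W, D); kinetics: (B, T, 1, H, W, D) with any T >= 1
        b, t = kinetics.shape[:2]
        t2_feat = self.t2_encoder(t2)                        # (B, F)
        kin_feat = self.kin_encoder(kinetics.flatten(0, 1))  # (B*T, F): phases encoded independently
        kin_feat = kin_feat.view(b, t, -1).mean(dim=1)       # temporal mean pooling -> (B, F)
        return self.head(torch.cat([t2_feat, kin_feat], dim=1))  # (B, num_classes) logits


model = MultiStreamSketch().eval()
with torch.no_grad():
    t2 = torch.randn(2, 1, 32, 32, 16)  # toy size; the card uses 128 x 128 x 64
    for n_phases in (2, 5):
        logits = model(t2, torch.randn(2, n_phases, 1, 32, 32, 16))
        print(n_phases, tuple(logits.shape))  # (2, 3) for both phase counts
```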

### Training Protocol
- **Optimizer:** AdamW (LR = 2e-4, weight decay = 1e-4).
- **Scheduler:** CosineAnnealingLR (T_max=50) for smooth convergence.
- **Loss:** **Focal Loss (gamma=2.0)** with dynamic class weighting to prioritize the rare "Benign" and "Malignant" cases.
- **Optimization:** Mixed precision (AMP) and gradient accumulation over 4 steps, i.e. an effective batch size of 4 with a per-step batch size of 1.
|
|
## Performance & Sustainability
The model was trained on the **IDUN HPC Cluster**.
- **Carbon Tracking:** Training included environmental footprint monitoring (approximately 0.25 kW power draw per GPU, on the low-carbon Norwegian grid).
- **Stability:** Evaluation shows consistent convergence across all 5 cross-validation folds; the patient-level split was audited to check that no patient appears in both training and validation data.
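
When a patient can contribute more than one case, folds must be built at the patient level so that no patient contributes data to both training and validation. The following pure-Python sketch (hypothetical helpers, not the repository's split code) assigns whole patients to 5 folds and runs the kind of leakage check such an audit performs; the case and patient IDs are made up.

```python
import random


def patient_kfold(case_to_patient: dict, n_folds: int = 5, seed: int = 42):
    """Yield (train_cases, val_cases) with every patient's cases kept in a single fold."""
    patients = sorted(set(case_to_patient.values()))
    random.Random(seed).shuffle(patients)
    fold_of = {pid: i % n_folds for i, pid in enumerate(patients)}  # round-robin over shuffled patients
    for fold in range(n_folds):
        val = [c for c, p in case_to_patient.items() if fold_of[p] == fold]
        train = [c for c, p in case_to_patient.items() if fold_of[p] != fold]
        yield train, val


def audit_split(train, val, case_to_patient) -> None:
    """Raise ValueError if any patient appears on both sides of a split."""
    shared = {case_to_patient[c] for c in train} & {case_to_patient[c] for c in val}
    if shared:
        raise ValueError(f"patient leakage: {sorted(shared)}")


# Made-up example: 20 patients, each with a left and a right breast case.
case_to_patient = {
    f"P{p:03d}_{side}": f"P{p:03d}" for p in range(20) for side in ("L", "R")
}
folds = list(patient_kfold(case_to_patient))
for train, val in folds:
    audit_split(train, val, case_to_patient)
print([len(val) for _, val in folds])  # [8, 8, 8, 8, 8]
```

This sketch does not stratify by class; with imbalanced labels, a stratified group split (for example scikit-learn's `StratifiedGroupKFold`) may be preferable.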
|
|
## Explainability (Grad-CAM)
The model includes an explainability wrapper that targets `layer4` of the T2 encoder to generate 3D Grad-CAM heatmaps, helping to identify the anatomical regions that most influence the "Malignant" prediction.
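
Below is a minimal 3D Grad-CAM sketch built on plain PyTorch hooks. It assumes the model's forward takes `(t2, kinetics)` and returns class logits; the function `grad_cam_3d`, the toy model and the class index used for "Malignant" are hypothetical, and the repository's own wrapper may differ. Following the standard Grad-CAM recipe, the gradients of the target logit with respect to the layer's activations are averaged over the three spatial axes to give one weight per channel, and the ReLU of the weighted channel sum is upsampled to the input grid.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam_3d(model, target_layer, inputs, class_idx):
    """Return a (B, *spatial) heatmap in [0, 1] matching the spatial shape of inputs[0]."""
    store = {}

    def forward_hook(module, args, output):
        store["acts"] = output
        output.register_hook(lambda grad: store.__setitem__("grads", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.eval()
        logits = model(*inputs)
        model.zero_grad()
        logits[:, class_idx].sum().backward()
    finally:
        handle.remove()

    acts, grads = store["acts"], store["grads"]        # (B, C, d, h, w)
    weights = grads.mean(dim=(2, 3, 4), keepdim=True)  # one importance weight per channel
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=inputs[0].shape[2:], mode="trilinear", align_corners=False)
    cam = cam.squeeze(1)
    flat = cam.flatten(1)
    lo, hi = flat.min(1)[0], flat.max(1)[0]
    cam = (cam - lo.view(-1, 1, 1, 1)) / (hi - lo).clamp_min(1e-8).view(-1, 1, 1, 1)
    return cam.detach()


class ToyTwoStream(nn.Module):
    """Hypothetical stand-in exposing a `t2_encoder.layer4` submodule."""

    def __init__(self):
        super().__init__()
        self.t2_encoder = nn.Module()
        self.t2_encoder.layer4 = nn.Sequential(nn.Conv3d(1, 8, 3, stride=2, padding=1), nn.ReLU())
        self.kin_fc = nn.Linear(1, 8)
        self.head = nn.Linear(16, 3)

    def forward(self, t2, kinetics):
        t2_feat = self.t2_encoder.layer4(t2).mean(dim=(2, 3, 4))                 # (B, 8)
        kin_feat = self.kin_fc(kinetics.mean(dim=(1, 2, 3, 4, 5)).unsqueeze(1))  # (B, 8)
        return self.head(torch.cat([t2_feat, kin_feat], dim=1))


torch.manual_seed(0)
model = ToyTwoStream()
t2 = torch.randn(1, 1, 16, 32, 32)  # toy volume
kinetics = torch.randn(1, 3, 1, 16, 32, 32)
MALIGNANT = 2  # hypothetical index of the "Malignant" class
heatmap = grad_cam_3d(model, model.t2_encoder.layer4, (t2, kinetics), MALIGNANT)
print(heatmap.shape)  # torch.Size([1, 16, 32, 32])
```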
|
|
## Evaluation Results
The primary metric is ROC AUC (area under the receiver operating characteristic curve).
Detailed performance graphs and confusion matrices for the ensemble can be found in the associated [GitHub Repository](https://github.com/THOUAN-Simon/ODELIA_Challenge_CV_DL).
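
For reference, here is a dependency-free sketch of AUROC in its rank (Mann-Whitney) form: the probability that a randomly chosen positive case scores higher than a randomly chosen negative one, with ties counted as one half. For three classes it is extended one-vs-rest and macro-averaged; whether the reported AUROC uses exactly this averaging is an assumption, and the predictions below are made up.

```python
from itertools import product


def binary_auroc(scores, labels):
    """AUROC = P(score of a positive > score of a negative); ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative case")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))


def macro_ovr_auroc(probs, targets, num_classes=3):
    """Average of one-vs-rest AUROCs; probs[i][k] is the predicted probability of class k."""
    per_class = [
        binary_auroc([p[k] for p in probs], [int(t == k) for t in targets])
        for k in range(num_classes)
    ]
    return sum(per_class) / num_classes, per_class


# Made-up predictions for classes (Normal, Benign, Malignant).
probs = [
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.5, 0.25, 0.25],
    [0.1, 0.2, 0.7],
    [0.2, 0.3, 0.5],
]
targets = [0, 0, 1, 1, 2, 2]
macro, per_class = macro_ovr_auroc(probs, targets)
print(per_class, round(macro, 4))  # [1.0, 0.75, 1.0] 0.9167
```

For real evaluations, `sklearn.metrics.roc_auc_score(y_true, y_score, multi_class="ovr", average="macro")` computes the same one-vs-rest macro average far more efficiently.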
|
|
## How to Load the Model
This model requires the `MultiStreamResNet` class from the `utils.py` file.
|
|
```python
import torch
from utils import MultiStreamResNet  # defined in the repository's utils.py

# Initialize the model with the Phase 3 dimensions
model = MultiStreamResNet(num_classes=3, hidden_dim=256)

# Load the weights for a specific fold (here fold 0); the checkpoint is a plain state_dict
state_dict = torch.load("best_resnet_odelia_fold0.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # disable dropout for inference
```
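
The evaluation section refers to an ensemble. A common way to combine cross-validation models is to average their softmax probabilities; the sketch below does this for any list of models with a `(t2, kinetics) -> logits` interface. The averaging rule and the per-fold file names for folds other than 0 are assumptions, not the repository's documented procedure, and tiny hypothetical stand-in models are used so the example runs without `utils.py`.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ensemble_predict(models, t2, kinetics):
    """Average class probabilities over fold models; returns (B, num_classes)."""
    probs = [torch.softmax(m.eval()(t2, kinetics), dim=1) for m in models]
    return torch.stack(probs).mean(dim=0)


# With real checkpoints this might look like (hypothetical file names for folds 1-4):
# models = []
# for k in range(5):
#     m = MultiStreamResNet(num_classes=3, hidden_dim=256)
#     m.load_state_dict(torch.load(f"best_resnet_odelia_fold{k}.pth", map_location="cpu", weights_only=True))
#     models.append(m)


class StandIn(nn.Module):
    """Hypothetical toy model with the same (t2, kinetics) -> logits interface."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 3)

    def forward(self, t2, kinetics):
        feats = torch.stack([t2.mean(dim=(1, 2, 3, 4)), kinetics.mean(dim=(1, 2, 3, 4, 5))], dim=1)
        return self.fc(feats)


torch.manual_seed(0)
models = [StandIn() for _ in range(5)]
t2 = torch.randn(2, 1, 8, 8, 4)
kinetics = torch.randn(2, 3, 1, 8, 8, 4)
probs = ensemble_predict(models, t2, kinetics)
print(probs.shape, probs.sum(dim=1))  # torch.Size([2, 3]) and rows summing to 1
```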
|
|
## Limitations & Ethical Considerations
This model is for research purposes only. It was trained on the ODELIA proprietary dataset.
Predictions should not be used for clinical diagnosis without professional medical supervision.