Advanced Multi-Stream 3D-ResNet for ODELIA (Phase 3)

This repository contains the "Better Model" (Phase 3) for breast cancer classification on multi-parametric MRI. It is the most stable and optimized iteration so far, with a hybrid architecture designed to process high-resolution 3D sequences within a tight memory budget (training runs with a batch size of 1).

Key Engineering Breakthroughs (Phase 3)

  • Instance Normalization Integration: Replaced Batch Normalization with 3D Instance Normalization (InstanceNorm3d). This fix lets the model train stably with a batch size of 1 and prevents the collapse observed in earlier versions.
  • Dynamic Temporal Fusion: Uses a Multi-Stream approach where the Kinetic stream processes a variable number of post-contrast phases.
  • Improved Temporal Aggregation: Selected Temporal Mean Pooling over Recurrent units (LSTM/RNN) to ensure a more robust global temporal summary and prevent overfitting on this specific dataset size.
  • Explainable AI (XAI): Full support for 3D Grad-CAM visualization on the T2 structural stream.

Mathematical & Architectural Refinement

  • From BatchNorm to InstanceNorm: Early versions (V3) suffered from performance collapse (AUROC ~0.55) due to the conflict between a batch size of 1 and Batch Normalization layers. Switching to Instance Normalization resolved this by stabilizing feature maps regardless of batch size.
  • Robust Temporal Aggregation: To prevent the massive overfitting observed with LSTM units (~0.0003 training loss but poor validation), we implemented Temporal Mean Pooling. This non-learnable layer effectively summarizes the enhancement kinetics without the parameter overhead of recurrent networks.
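
The normalization issue described above can be illustrated with a small experiment. The sketch below is not the repository's code; it only compares how `nn.BatchNorm3d` and `nn.InstanceNorm3d` behave on a single volume. BatchNorm normalizes with the current batch's statistics in training mode but with running averages in evaluation mode, so with a batch size of 1 the two modes can produce very different outputs. InstanceNorm (with its default `track_running_stats=False`) uses per-sample statistics in both modes. The tensor shape and the helper `train_eval_gap` are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One volume with non-zero mean and non-unit variance: (batch=1, channels, D, H, W).
x = 5.0 + 3.0 * torch.randn(1, 8, 4, 16, 16)

inst_norm = nn.InstanceNorm3d(8, affine=True)
batch_norm = nn.BatchNorm3d(8)


def train_eval_gap(layer, inp):
    """Largest absolute difference between train-mode and eval-mode outputs."""
    layer.train()
    out_train = layer(inp)
    layer.eval()
    out_eval = layer(inp)
    return (out_train - out_eval).abs().max().item()


inst_gap = train_eval_gap(inst_norm, x)    # per-sample statistics in both modes
batch_gap = train_eval_gap(batch_norm, x)  # batch statistics vs. running averages

print(f"InstanceNorm3d train/eval gap: {inst_gap:.4f}")
print(f"BatchNorm3d    train/eval gap: {batch_gap:.4f}")
```

This is one plausible mechanism behind the instability, shown on random data; it does not reproduce the reported AUROC collapse itself.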

Advanced Pre-processing & Augmentation

  • Automatic ROI Cropping: Uses MONAI's CropForegroundd to crop each volume to its foreground bounding box, so the spatial resolution is spent on breast tissue rather than on background and irrelevant thoracic signal.
  • Modality Dropout (20%): A custom transform that randomly masks T2 or Kinetic phases during training, preventing co-adaptation and forcing the model to learn multi-modal features.
  • Focal Loss Strategy: Implemented to address class imbalance, specifically targeting the "Benign" class detection failure observed in baseline models.
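
A modality dropout transform of the kind listed above could look roughly like the following. This is a minimal sketch under stated assumptions, not the repository's transform: the class name `ModalityDropout`, the dictionary keys `"t2"` and `"kinetics"`, masking by zero-filling, and dropping a whole stream (rather than individual kinetic phases) are all hypothetical choices.

```python
import torch


class ModalityDropout:
    """With probability p, zero out one randomly chosen modality (assumed behavior)."""

    def __init__(self, p=0.2, keys=("t2", "kinetics"), generator=None):
        self.p = p
        self.keys = keys            # hypothetical dictionary keys
        self.generator = generator  # optional torch.Generator for reproducibility

    def __call__(self, sample):
        out = dict(sample)  # shallow copy so the input dict is left untouched
        if torch.rand(1, generator=self.generator).item() < self.p:
            idx = torch.randint(len(self.keys), (1,), generator=self.generator).item()
            key = self.keys[idx]
            out[key] = torch.zeros_like(out[key])  # mask the whole stream
        return out


# Usage on a toy sample: T2 is (C, D, H, W), kinetics is (T, C, D, H, W).
sample = {
    "t2": torch.ones(1, 8, 16, 16),
    "kinetics": torch.ones(3, 1, 8, 16, 16),
}
augmented = ModalityDropout(p=0.2)(sample)
```

Only one stream is masked at a time in this sketch, so the model always receives at least one informative input.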

Model Description

  • Developed by: THOUAN Simon
  • Architecture: Multi-Stream 3D ResNet-18 (Custom InstanceNorm version).
  • Task: 3-class classification (Normal, Benign, Malignant).
  • Input Modalities:
    1. T2 Stream: Structural 3D MRI volume.
    2. Kinetics Stream: Temporal sequence of subtraction volumes (Post_i - Pre).
  • Spatial Resolution: 128 x 128 x 64.

Technical Details

Architecture Components

  • Encoders: Two parallel 3D ResNet-18 backbones with InstanceNorm3D.
  • Temporal Handling: The kinetics encoder processes each time point independently; features are then aggregated via mean pooling across the time dimension.
  • Classification Head: A 2-layer MLP with 40% Dropout to prevent overfitting on complex features.
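
The data flow described in these components can be sketched as follows. To keep the example short, a single convolution block stands in for each 3D ResNet-18 backbone; the class name `MultiStreamSketch`, the feature sizes, and the input layout (`t2` as `(B, 1, D, H, W)`, `kinetics` as `(B, T, 1, D, H, W)`) are assumptions and may differ from the actual `MultiStreamResNet` in utils.py.

```python
import torch
import torch.nn as nn


def tiny_encoder(out_dim):
    # Stand-in for a 3D ResNet-18 backbone with InstanceNorm (assumption).
    return nn.Sequential(
        nn.Conv3d(1, out_dim, kernel_size=3, padding=1),
        nn.InstanceNorm3d(out_dim, affine=True),
        nn.ReLU(),
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(),
    )


class MultiStreamSketch(nn.Module):
    def __init__(self, num_classes=3, feat_dim=32, hidden_dim=64):
        super().__init__()
        self.t2_encoder = tiny_encoder(feat_dim)
        self.kin_encoder = tiny_encoder(feat_dim)
        # 2-layer MLP head with 40% dropout.
        self.head = nn.Sequential(
            nn.Linear(2 * feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, t2, kinetics):
        b, t = kinetics.shape[:2]
        f_t2 = self.t2_encoder(t2)                          # (B, F)
        # Encode every time point independently by folding T into the batch axis.
        f_kin = self.kin_encoder(kinetics.flatten(0, 1))    # (B*T, F)
        # Temporal mean pooling: average features over the T phases.
        f_kin = f_kin.view(b, t, -1).mean(dim=1)            # (B, F)
        return self.head(torch.cat([f_t2, f_kin], dim=1))   # (B, num_classes)


model_sketch = MultiStreamSketch().eval()
t2 = torch.randn(1, 1, 8, 16, 16)
logits_3_phases = model_sketch(t2, torch.randn(1, 3, 1, 8, 16, 16))
logits_5_phases = model_sketch(t2, torch.randn(1, 5, 1, 8, 16, 16))
```

Because the mean over the time axis has no parameters, the same weights accept any number of post-contrast phases, and the result does not depend on the order of the phases.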

Training Protocol

  • Optimizer: AdamW (LR = 2e-4, Weight Decay = 1e-4).
  • Scheduler: CosineAnnealingLR (T_max=50) for smooth convergence.
  • Loss: Focal Loss (gamma=2.0) with dynamic class weighting to prioritize the rare "Benign" and "Malignant" cases.
  • Optimization: Mixed Precision (AMP) and Gradient Accumulation (steps=4).
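
Put together, the protocol above could be wired up roughly as shown below. This is a sketch under assumptions, not the training script: a linear layer stands in for the network, random tensors stand in for data, "dynamic class weighting" is interpreted as inverse-frequency weights computed from the training labels, and the scheduler is assumed to step once per epoch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=2.0, class_weights=None):
    """Multi-class focal loss: -(1 - p_t)^gamma * log(p_t), optionally class-weighted."""
    log_probs = F.log_softmax(logits.float(), dim=1)
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    loss = -((1.0 - log_pt.exp()) ** gamma) * log_pt
    if class_weights is not None:
        loss = loss * class_weights[targets]
    return loss.mean()


use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(10, 3).to(device)  # stand-in for the real network (assumption)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
accum_steps = 4

# Inverse-frequency class weights from (toy) training labels (assumption).
train_labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2, 2])
counts = torch.bincount(train_labels, minlength=3).float()
class_weights = (counts.sum() / (3 * counts.clamp(min=1))).to(device)

model.train()
optimizer.zero_grad(set_to_none=True)
for step in range(8):  # one toy "epoch" of 8 samples, batch size 1
    x = torch.randn(1, 10, device=device)
    y = torch.randint(0, 3, (1,), device=device)
    with torch.autocast(device_type=device, enabled=use_amp):
        logits = model(x)
        # Divide by accum_steps so the accumulated gradient is an average.
        loss = focal_loss(logits, y, gamma=2.0, class_weights=class_weights) / accum_steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
scheduler.step()  # assumed to be called once per epoch
```

With a batch size of 1 and 4 accumulation steps, each optimizer update sees the averaged gradient of 4 samples. Setting `gamma=0` and omitting the weights reduces `focal_loss` to ordinary cross-entropy.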

Performance & Sustainability

The model was trained on the IDUN HPC Cluster.

  • Carbon Tracking: Training included environmental footprint monitoring (~0.25 kW/GPU on a low-carbon Norwegian grid).
  • Stability: Evaluation shows consistent convergence across 5 folds, validated by a rigorous audit of the patient split.

Explainability (Grad-CAM)

The model includes a Grad-CAM wrapper for explainability. It targets layer4 of the T2 encoder to generate 3D heatmaps that highlight the anatomical regions driving the "Malignant" prediction.
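
For readers unfamiliar with the technique, a generic 3D Grad-CAM can be sketched with PyTorch hooks as follows. This is not the repository's wrapper: the function `grad_cam_3d`, the toy network, and the choice of class index 2 for "Malignant" (assuming the order Normal, Benign, Malignant) are illustrative assumptions. In the real model the target layer would be layer4 of the T2 encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_cam_3d(model, target_layer, volume, class_idx):
    """Return a heatmap in [0, 1] with the same spatial size as `volume`."""
    store = {}

    def forward_hook(module, inputs, output):
        store["act"] = output
        # Capture the gradient of the class score w.r.t. this activation.
        output.register_hook(lambda grad: store.__setitem__("grad", grad))

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad()
        logits = model(volume)
        logits[0, class_idx].backward()
    finally:
        handle.remove()

    act, grad = store["act"].detach(), store["grad"]
    # Channel weights: gradients averaged over the three spatial axes.
    weights = grad.mean(dim=(2, 3, 4), keepdim=True)
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))
    # Upsample the coarse map back to the input resolution.
    cam = F.interpolate(cam, size=volume.shape[2:], mode="trilinear", align_corners=False)
    cam = cam - cam.min()
    return cam / cam.max().clamp(min=1e-8)


# Toy stand-in for the T2 encoder plus classifier (assumption).
toy_net = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(4, 8, kernel_size=3, padding=1, stride=2),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
    nn.Linear(8, 3),
).eval()

volume = torch.randn(1, 1, 8, 16, 16)
heatmap = grad_cam_3d(toy_net, toy_net[2], volume, class_idx=2)
```

The heatmap can then be overlaid on the T2 volume slice by slice.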

Evaluation Results

The primary metric is ROC AUC (area under the receiver operating characteristic curve). Detailed performance graphs and confusion matrices for the ensemble can be found in the associated GitHub repository.
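
The card does not state how ROC AUC is averaged over the three classes. One common convention is a macro-averaged one-vs-rest AUC, sketched below in plain Python; the averaging scheme, the function names, and the toy probabilities are assumptions for illustration only, not reported results.

```python
def binary_auc(scores, labels):
    """AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(probs, labels, num_classes=3):
    """Average the one-vs-rest AUC of each class (assumed averaging scheme)."""
    aucs = []
    for c in range(num_classes):
        class_scores = [row[c] for row in probs]
        class_labels = [1 if y == c else 0 for y in labels]
        aucs.append(binary_auc(class_scores, class_labels))
    return sum(aucs) / num_classes


# Toy predicted probabilities for (Normal, Benign, Malignant) and true labels.
toy_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.6, 0.3],
    [0.2, 0.3, 0.5],
    [0.5, 0.4, 0.1],
    [0.2, 0.2, 0.6],
]
toy_labels = [0, 1, 2, 1, 2]
toy_auc = macro_ovr_auc(toy_probs, toy_labels)
```

The pairwise formulation is quadratic in the number of samples, which is fine for a validation fold; a library routine is preferable for larger sets.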

How to Load the Model

This model requires the MultiStreamResNet class from the utils.py file.

import torch
from utils import MultiStreamResNet

# Initialize model with Phase 3 dimensions
model = MultiStreamResNet(num_classes=3, hidden_dim=256)

# Load weights for a specific fold
model.load_state_dict(torch.load("best_resnet_odelia_fold0.pth", map_location="cpu"))
model.eval()

Limitations & Ethical Considerations

This model is for research purposes only. It was trained on the ODELIA proprietary dataset. Predictions should not be used for clinical diagnosis without professional medical supervision.
