Multi-Stream Swin-UNETR for ODELIA MRI Classification (Phase 2)

This repository contains the weights and inference code for an advanced Multi-Stream Vision Transformer architecture. Unlike the baseline, this model leverages the temporal dynamics of contrast-enhanced MRI and uses a Swin-Transformer backbone for high-dimensional feature extraction.

Key Improvements (Phase 2)

  • Architecture: Transition from CNN (ResNet) to Swin-UNETR (Swin Transformer for 3D Medical Images).
  • Multi-Stream Logic: Two parallel encoders processing different MRI modalities:
    1. T2 Stream: Captures structural morphology.
    2. Kinetics Stream: Captures temporal contrast enhancement by processing (Post-Contrast - Pre-Contrast) subtractions.
  • Explainability: Integrated Grad-CAM support for both T2 and Kinetic streams.
  • Training Strategy: Implemented Focal Loss for class imbalance, Gradient Accumulation, and Mixed Precision (AMP) for efficient HPC training.

Model Description

  • Developed by: THOUAN Simon
  • Architecture: Custom MultiStreamSwin (based on MONAI Swin-UNETR).
  • Task: Multi-class 3D Image Classification (3 classes: Normal, Benign, Malignant).
  • Inputs:
    1. T2-weighted volume.
    2. Kinetic sequence: T=4 phases of subtractions (Post - Pre).
  • Input Size: (128, 128, 64) per phase.
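
The kinetic input described above can be illustrated with a minimal sketch. It assumes that each of the four post-contrast phases is subtracted against the same pre-contrast volume; the PrepareTemporalKineticsd transform in utils.py may handle this differently. The function name `subtract_phases` and the toy values are hypothetical, and flat lists stand in for the 3D volumes that real code would hold in arrays.

```python
def subtract_phases(pre, posts):
    """Return one (post - pre) difference per post-contrast phase.

    `pre` is a flat list of voxel intensities; `posts` is a list of
    flat lists of the same length, one per post-contrast phase.
    """
    for post in posts:
        if len(post) != len(pre):
            raise ValueError("every phase must match the pre-contrast size")
    return [[p - q for p, q in zip(post, pre)] for post in posts]


# Toy data: 3 voxels, T=4 post-contrast phases (hypothetical values).
pre = [10.0, 20.0, 30.0]
posts = [
    [12.0, 20.0, 35.0],
    [15.0, 21.0, 40.0],
    [14.0, 21.0, 42.0],
    [13.0, 20.0, 43.0],
]
kinetics = subtract_phases(pre, posts)
print(len(kinetics), kinetics[0])
```

Voxels that do not take up contrast stay near zero in every phase, while enhancing tissue produces a curve across the four differences.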

Technical Architecture Details

The model uses two identical Swin-UNETR backbones initialized with pre-trained weights from the BTCV dataset.

  • Feature Size: 48.
  • Fusion: Features from all temporal phases and the T2 stream are concatenated into a high-dimensional vector (3840 features) before being passed to a Dropout-regularized MLP head.
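
One way to arrive at the 3840-feature figure is sketched below. It assumes that each stream contributes a pooled vector from the deepest Swin stage, and that this stage has 16 times the base feature size (channels doubling at each of four downsampling steps). Both points are assumptions about the implementation, not statements taken from utils.py.

```python
FEATURE_SIZE = 48        # from the model card
NUM_KINETIC_PHASES = 4   # T=4 subtraction phases
NUM_T2_VOLUMES = 1       # single T2-weighted volume

# Assumption: the deepest encoder stage has 16x the base feature size.
DEEPEST_STAGE_MULTIPLIER = 16

features_per_volume = FEATURE_SIZE * DEEPEST_STAGE_MULTIPLIER
fused_features = features_per_volume * (NUM_KINETIC_PHASES + NUM_T2_VOLUMES)

print(features_per_volume, fused_features)
```

Under these assumptions each volume yields 768 features, and five volumes give 3840, which matches the size quoted above.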

Training Details

  • Loss Function: Focal Loss (gamma = 2.0) with dynamic class weighting.
  • Optimization: Adam (LR = 1e-4).
  • Batch Size Simulation: Gradient Accumulation (steps=4), so each optimizer update uses gradients from 4 mini-batches and the effective batch size is 4 times the per-step batch size that fits on the GPU.
  • Hardware: Trained on the IDUN HPC Cluster (NTNU - Norway).
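
The loss listed above can be sketched in a framework-free form. This is an illustration only: the training code presumably operates on PyTorch tensors, and the card does not say how the dynamic class weights are computed. The inverse-frequency scheme in `inverse_frequency_weights`, and all function names and example counts, are assumptions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def inverse_frequency_weights(class_counts):
    """Assumed weighting scheme: rarer classes get larger weights."""
    total = sum(class_counts)
    n = len(class_counts)
    return [total / (n * c) for c in class_counts]


def focal_loss(logits, target, class_weights, gamma=2.0):
    """Focal loss for one sample: -w_t * (1 - p_t)^gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    return -class_weights[target] * (1.0 - p_t) ** gamma * math.log(p_t)


# Hypothetical class counts for Normal, Benign, Malignant.
weights = inverse_frequency_weights([60, 30, 10])

# A confidently correct sample is down-weighted relative to a wrong one.
easy = focal_loss([4.0, 0.0, 0.0], target=0, class_weights=weights)
hard = focal_loss([0.0, 4.0, 0.0], target=0, class_weights=weights)
print(easy < hard)
```

With gamma = 0 and unit weights the expression reduces to ordinary cross-entropy. Raising gamma shifts the training signal toward misclassified samples.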

Explainable AI (XAI)

This model is designed for transparency. The inference script generates 3D Grad-CAM Heatmaps for both the structural (T2) and functional (Kinetics) inputs, allowing clinicians to visualize the features driving the prediction.

Evaluation Results

The primary metric is ROC AUC (Area Under the Receiver Operating Characteristic Curve). Detailed performance graphs and confusion matrices for the ensemble can be found in the associated GitHub Repository.
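
For readers unfamiliar with the metric, the sketch below computes ROC AUC as the probability that a positive case is scored above a negative one, then averages it one-vs-rest across the three classes. The card does not state how AUC is aggregated across classes, so the macro one-vs-rest averaging is an assumption. The function names are hypothetical.

```python
def binary_auc(scores, labels):
    """ROC AUC via pairwise comparison of positive and negative scores."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(probs, labels, num_classes=3):
    """Unweighted mean of one-vs-rest AUCs (assumed aggregation)."""
    aucs = []
    for c in range(num_classes):
        scores = [row[c] for row in probs]
        binary = [1 if y == c else 0 for y in labels]
        aucs.append(binary_auc(scores, binary))
    return sum(aucs) / num_classes


print(binary_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

The pairwise form is quadratic in the number of samples. It is meant to show the definition, not to replace a library implementation on large test sets.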

How to use

The architecture requires the MultiStreamSwin and PrepareTemporalKineticsd classes provided in the utils.py file of this repository.

import torch
from utils import MultiStreamSwin

# Initialize the Multi-Stream Transformer
model = MultiStreamSwin(num_classes=3)

# Load Fold 0 weights
model.load_state_dict(torch.load("weights/best_swin_odelia_fold0.pth", map_location="cpu"))
model.eval()

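# Example: inference on T2 and Kinetics tensors (gradients are not needed
# for plain prediction, so the call can be wrapped in torch.no_grad()):
# with torch.no_grad():
#     output = model(t2_tensor, kinetics_tensor)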

Limitations & Ethical Considerations

This model is for research purposes only. It was trained on the ODELIA proprietary dataset. Predictions should not be used for clinical diagnosis without professional medical supervision.
