3D ResNet-18 for ODELIA MRI Classification
This repository contains the weights of a 3D ResNet-18 model trained on the ODELIA dataset (multi-parametric MRI). The model classifies 3D MRI volumes into 3 classes; training and evaluation used 5-fold cross-validation, with one model trained per fold.
Model Description
- Developed by: THOUAN Simon
- Model type: 3D Convolutional Neural Network (ResNet-18 architecture)
- Framework: PyTorch & MONAI
- Task: Multi-class 3D Image Classification (3 classes)
- Input: 5 MRI sequences (Pre, Post_1, Post_2, Sub_1, T2) concatenated as channels.
- Input Size: (128, 128, 64)
Architecture Details
The model is a 3D adaptation of the ResNet architecture provided by the MONAI library:
- Blocks: Basic block
- Layers: [2, 2, 2, 2] (Equivalent to ResNet-18)
- In-planes: [64, 128, 256, 512]
- Input Channels: 5
- Spatial Dimensions: 3D
Training Procedure
The model was trained on a high-performance computing cluster (IDUN) using the following configuration:
Hyperparameters
| Parameter | Value |
|---|---|
| Optimizer | Adam |
| Learning Rate | 1e-4 |
| Loss Function | CrossEntropyLoss |
| Batch Size | 4 (Training) / 2 (Validation) |
| Epochs | 50 |
| Validation | Every 2 epochs |
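The settings in the table can be sketched as a standard PyTorch training loop. This is a minimal sketch under assumptions, not the original training script: `train_one_fold` is a hypothetical helper, the datasets are assumed to yield `(image, label)` pairs, and the validation step tracks mean cross-entropy loss because the card does not say which quantity was monitored on validation epochs or how the "best" checkpoint was chosen.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_fold(model, train_ds, val_ds, epochs=50, device="cpu"):
    """Hypothetical helper: Adam (lr=1e-4), CrossEntropyLoss, batch 4/2, validate every 2 epochs."""
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()
    model.to(device)
    history = []  # (epoch, mean validation loss) pairs

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:  # validation every 2 epochs
            model.eval()
            total, count = 0.0, 0
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    total += loss_fn(model(x), y).item() * len(y)
                    count += len(y)
            history.append((epoch, total / count))
    return history


# Tiny synthetic run with a stand-in linear model (not the ResNet) to show the loop shape.
torch.manual_seed(0)
X = torch.randn(12, 5, 4, 4, 2)
y = torch.randint(0, 3, (12,))
toy = nn.Sequential(nn.Flatten(), nn.Linear(5 * 4 * 4 * 2, 3))
hist = train_one_fold(toy, TensorDataset(X[:8], y[:8]), TensorDataset(X[8:], y[8:]), epochs=4)
print(hist)
```

For the real model, `toy` would be replaced by the MONAI `ResNet` configured as in the loading snippet, and the datasets by the preprocessed 5-channel volumes.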
Cross-Validation Strategy
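The dataset was split into 5 folds (A, B, C, D, E). Each fold serves once as the validation set while the remaining four are used for training, so every case is validated exactly once: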
- Fold 0: Val = A, Train = B+C+D+E
- Fold 1: Val = B, Train = A+C+D+E
- Fold 2: Val = C, Train = A+B+D+E
- Fold 3: Val = D, Train = A+B+C+E
- Fold 4: Val = E, Train = A+B+C+D
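The rotation above can be generated programmatically. A small sketch, where `fold_split` is a hypothetical helper and fold labels are the letters used in this card:

```python
FOLDS = ["A", "B", "C", "D", "E"]


def fold_split(fold_index, folds=FOLDS):
    """Return (validation fold, training folds) for the given fold index."""
    val = folds[fold_index]
    train = [f for f in folds if f != val]
    return val, train


for i in range(len(FOLDS)):
    val, train = fold_split(i)
    print(f"Fold {i}: Val = {val}, Train = {'+'.join(train)}")
```

The first two printed lines reproduce Fold 0 and Fold 1 as listed above.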
Preprocessing (MONAI Transforms)
- Resizing: All volumes resized to $128 \times 128 \times 64$.
- Normalization: Intensity scaling for each sequence.
- Concatenation: The 5 MRI sequences are stacked into a single 5-channel tensor.
Evaluation Results
The primary metric is ROC AUC (Area Under the Receiver Operating Characteristic Curve). Detailed performance graphs and confusion matrices for the ensemble can be found in the associated GitHub repository.
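The card does not state how ROC AUC was aggregated over the 3 classes. A common choice, assumed here, is the macro-averaged one-vs-rest AUC, which is what scikit-learn's `roc_auc_score(..., multi_class="ovr")` computes with its default `average="macro"`. The dependency-free sketch below uses the rank interpretation of AUC (probability that a random positive scores above a random negative, ties counted as 0.5); it is O(n²) and meant for illustration only.

```python
def binary_auc(scores, labels):
    """AUC = P(score of random positive > score of random negative), ties count 0.5."""
    pos = [s for s, l in zip(scores, labels) if l]
    neg = [s for s, l in zip(scores, labels) if not l]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(probs, targets, num_classes=3):
    """Macro average of one-vs-rest AUCs; probs holds one per-class probability list per case."""
    aucs = []
    for c in range(num_classes):
        scores = [row[c] for row in probs]
        labels = [t == c for t in targets]
        aucs.append(binary_auc(scores, labels))
    return sum(aucs) / num_classes


# Toy predictions where every class is ranked perfectly.
probs = [[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.2, 0.2, 0.6], [0.6, 0.3, 0.1]]
targets = [0, 1, 2, 0]
print(macro_ovr_auc(probs, targets))  # 1.0
```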
How to Load the Model
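```python
import torch
from monai.networks.nets import ResNet

# Initialize the architecture (must match the training configuration)
model = ResNet(
    block="basic",
    layers=[2, 2, 2, 2],
    block_inplanes=[64, 128, 256, 512],
    n_input_channels=5,
    num_classes=3,
    spatial_dims=3,
)

# Load the fold-0 weights onto the CPU (the file contains a plain state dict)
state_dict = torch.load(
    "weights/best_resnet_odelia_fold0.pth", map_location="cpu", weights_only=True
)
model.load_state_dict(state_dict)

# Inference mode: batch norm uses its stored running statistics
model.eval()
```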
Limitations & Ethical Considerations
This model is for research purposes only. It was trained on the ODELIA proprietary dataset. Predictions should not be used for clinical diagnosis without professional medical supervision.