3D ResNet-18 for ODELIA MRI Classification

This repository contains the weights of a 3D ResNet-18 model trained on the ODELIA dataset (multi-parametric MRI). The model classifies each MRI study into one of 3 classes and was trained with 5-fold cross-validation.

Model Description

  • Developed by: THOUAN Simon
  • Model type: 3D Convolutional Neural Network (ResNet-18 architecture)
  • Framework: PyTorch & MONAI
  • Task: Multi-class 3D Image Classification (3 classes)
  • Input: 5 MRI sequences (Pre, Post_1, Post_2, Sub_1, T2) concatenated as channels.
  • Input Size: (128, 128, 64)

Architecture Details

The model is a 3D adaptation of the ResNet architecture provided by the MONAI library:

  • Blocks: Basic block
  • Layers: [2, 2, 2, 2] (Equivalent to ResNet-18)
  • In-planes: [64, 128, 256, 512]
  • Input Channels: 5
  • Spatial Dimensions: 3D

Training Procedure

The model was trained on a high-performance computing cluster (IDUN) using the following configuration:

Hyperparameters

Parameter        Value
Optimizer        Adam
Learning Rate    1e-4
Loss Function    CrossEntropyLoss
Batch Size       4 (Training) / 2 (Validation)
Epochs           50
Validation       Every 2 epochs
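
The configuration above could translate into a training loop along the following lines. This is a minimal sketch, not the author's training script: the `train` function, the toy model, and the random data are hypothetical, and validation accuracy is used here only as a simple stand-in for whatever criterion was actually used to select the best checkpoint.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, val_loader, epochs=50, val_interval=2, lr=1e-4, device="cpu"):
    """Train with Adam and cross-entropy; validate every `val_interval` epochs."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []  # list of (epoch, validation accuracy)

    for epoch in range(1, epochs + 1):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()

        if epoch % val_interval == 0:
            model.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in val_loader:
                    preds = model(images.to(device)).argmax(dim=1).cpu()
                    correct += (preds == labels).sum().item()
                    total += labels.numel()
            history.append((epoch, correct / total))
    return history


# Toy stand-ins (assumptions): a linear model and random 5-channel volumes.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(5 * 8 * 8 * 4, 3))
toy_data = TensorDataset(torch.rand(8, 5, 8, 8, 4), torch.randint(0, 3, (8,)))
history = train(
    toy_model,
    DataLoader(toy_data, batch_size=4, shuffle=True),  # training batch size 4
    DataLoader(toy_data, batch_size=2),                # validation batch size 2
    epochs=4,                                          # shortened from 50 for the demo
)
```

In a real run, the toy model would be replaced by the MONAI ResNet and the loaders by the ODELIA fold splits.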

Cross-Validation Strategy

The dataset was split into 5 folds (A, B, C, D, E). In each round, one fold is held out for validation and the model is trained on the remaining four:

  • Fold 0: Val = A, Train = B+C+D+E
  • Fold 1: Val = B, Train = A+C+D+E
  • (Continuing for all 5 folds)
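
The rotation described above can be enumerated with a few lines of Python. The helper name `make_folds` is hypothetical, and the mapping of fold indices 2 to 4 onto groups C, D, and E is an assumption extrapolated from the two folds listed.

```python
# Hypothetical helper: enumerate the hold-one-group-out splits over groups A-E.
GROUPS = ["A", "B", "C", "D", "E"]


def make_folds(groups):
    """Return one dict per fold with its validation group and training groups."""
    folds = []
    for index, val_group in enumerate(groups):
        train_groups = [g for g in groups if g != val_group]
        folds.append({"fold": index, "val": val_group, "train": train_groups})
    return folds


folds = make_folds(GROUPS)
for f in folds:
    print(f"Fold {f['fold']}: Val = {f['val']}, Train = {'+'.join(f['train'])}")
```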

Preprocessing (MONAI Transforms)

  • Resizing: All volumes resized to $128 \times 128 \times 64$.
  • Normalization: Intensity scaling for each sequence.
  • Concatenation: The 5 MRI sequences are stacked into a single 5-channel tensor.

Evaluation Results

The primary metric is ROC AUC (area under the receiver operating characteristic curve). Detailed performance graphs and confusion matrices for the ensemble can be found in the associated GitHub repository.
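
For a 3-class problem, ROC AUC has to be reduced to a single number somehow. The card does not say which averaging was used, so the sketch below assumes macro-averaged one-vs-rest AUC; the function names and the toy predictions are illustrative only.

```python
def binary_auc(labels, scores):
    """AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_ovr_auc(y_true, probs, num_classes=3):
    """Average the one-vs-rest AUC over all classes."""
    aucs = []
    for c in range(num_classes):
        labels = [1 if y == c else 0 for y in y_true]
        scores = [p[c] for p in probs]
        aucs.append(binary_auc(labels, scores))
    return sum(aucs) / num_classes


# Toy predictions (made up for illustration): six samples, three classes.
y_true = [0, 1, 2, 0, 1, 2]
probs = [
    [0.8, 0.1, 0.1],
    [0.2, 0.6, 0.2],
    [0.1, 0.2, 0.7],
    [0.3, 0.4, 0.3],
    [0.3, 0.4, 0.3],
    [0.2, 0.3, 0.5],
]
auc = macro_ovr_auc(y_true, probs)
```

The pairwise loop is quadratic in the number of samples, which is fine for a validation fold but would be replaced by a rank-based implementation for large sets.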

How to Load the Model

import torch
from monai.networks.nets import ResNet

# Initialize architecture
model = ResNet(
    block="basic", 
    layers=[2, 2, 2, 2], 
    block_inplanes=[64, 128, 256, 512],
    n_input_channels=5, 
    num_classes=3, 
    spatial_dims=3
)

# Load weights
state_dict = torch.load("weights/best_resnet_odelia_fold0.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

Limitations & Ethical Considerations

This model is for research purposes only. It was trained on the proprietary ODELIA dataset. Predictions should not be used for clinical diagnosis without professional medical supervision.
