MSLesSeg Framework: MS Lesion Segmentation

A deep learning framework for automated Multiple Sclerosis (MS) lesion segmentation from multimodal MRI scans (FLAIR, T1, T2), built around the MSLesSeg dataset.

Features

  • 7 State-of-the-Art Architectures: UNet, AttentionUNet, UNet++, SwinUNETR, UNETR, SegResNet, VNet
  • Multimodal MRI Support: FLAIR, T1, T2 channel input
  • Advanced Loss Functions: Dice, Dice+CE, Dice+Focal, Tversky, FocalTversky, Boundary, Lovasz
  • Comprehensive Evaluation: Voxel-level, lesion-level, boundary, and spatial metrics
  • Explainability: Grad-CAM, saliency maps, uncertainty estimation, attention visualization
  • Training Infrastructure: PyTorch Lightning, mixed precision, gradient checkpointing
  • Hyperparameter Optimization: Optuna integration with pruning
  • Inference Optimization: Sliding window, test-time augmentation, ensemble methods
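
As an illustration of the Dice+CE entry in the loss list, the following is a minimal NumPy sketch of a combined soft Dice and binary cross-entropy loss. It is not the framework's implementation (the contents of models/losses.py are not shown here); the function name `dice_ce_loss`, the equal default weighting, and the smoothing constant `eps` are assumptions made for the example.

```python
import numpy as np

def dice_ce_loss(probs, target, eps=1e-6, dice_weight=1.0, ce_weight=1.0):
    """Soft Dice + binary cross-entropy on foreground probabilities in [0, 1].

    Hypothetical helper for illustration; weights and eps are assumptions.
    """
    # Clip so that log() never sees exactly 0 or 1.
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0 - eps)
    target = np.asarray(target, dtype=np.float64)

    # Soft Dice: overlap between probabilities and the binary reference mask.
    intersection = (probs * target).sum()
    dice = (2.0 * intersection + eps) / (probs.sum() + target.sum() + eps)

    # Binary cross-entropy averaged over all voxels.
    ce = -(target * np.log(probs) + (1.0 - target) * np.log(1.0 - probs)).mean()

    return dice_weight * (1.0 - dice) + ce_weight * ce
```

The Dice term responds to overlap regardless of how few foreground voxels there are, while the cross-entropy term supplies a per-voxel gradient; this pairing is a common choice when lesions occupy a small fraction of the volume.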

Quick Start

Installation

# After cloning the repository, enter its root directory
cd mslesseg_framework

# Install dependencies
pip install -r requirements.txt

Data Preparation

  1. Download the MSLesSeg dataset from Figshare.
  2. Place the data in the data/MSLesSeg/ directory.
  3. Run preprocessing:
python run_pipeline.py --config configs/base_config.yaml --preprocess-only
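
The project structure lists normalization as one of the preprocessing steps. Which normalization the pipeline applies is not stated here, so the sketch below assumes a z-score normalization restricted to a foreground (brain) mask, a common choice for MRI intensities. The function name `zscore_normalize` and the nonzero-voxel fallback mask are assumptions.

```python
import numpy as np

def zscore_normalize(volume, mask=None, eps=1e-8):
    """Z-score a volume using statistics from the masked voxels only.

    Hypothetical helper: if no mask is given, nonzero voxels are treated
    as foreground (an assumption, not the framework's documented rule).
    """
    volume = np.asarray(volume, dtype=np.float32)
    if mask is None:
        mask = volume > 0
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return np.zeros_like(volume)

    mean = volume[mask].mean()
    std = volume[mask].std()
    out = (volume - mean) / (std + eps)
    out[~mask] = 0.0  # keep background at a fixed value
    return out
```

Computing the statistics inside the mask keeps the large zero-valued background from dominating the mean and standard deviation.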

Training

# Train with default config (SwinUNETR)
python run_pipeline.py --config configs/base_config.yaml

# Train specific model
python run_pipeline.py --config configs/base_config.yaml --model UNet

# Train with custom hyperparameters
python run_pipeline.py --config configs/base_config.yaml --model SwinUNETR --batch-size 1 --lr 1e-4 --epochs 400
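
The commands above combine a YAML config with command-line overrides. One way such overrides could be merged is sketched below using `argparse`. The flag names are taken from the commands above and the config keys from configs/base_config.yaml as shown in the Configuration section; the functions `parse_overrides` and `apply_overrides` are hypothetical and not necessarily how run_pipeline.py works.

```python
import argparse
import copy

def parse_overrides(argv=None):
    """Parse the override flags used in the training commands."""
    parser = argparse.ArgumentParser(description="Sketch of CLI overrides")
    parser.add_argument("--config", required=True)
    parser.add_argument("--model", default=None)
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--epochs", type=int, default=None)
    return parser.parse_args(argv)

def apply_overrides(config, args):
    """Return a copy of the config with any given CLI values applied."""
    cfg = copy.deepcopy(config)  # leave the loaded config untouched
    if args.model is not None:
        cfg["model"]["name"] = args.model
    if args.batch_size is not None:
        cfg["training"]["batch_size"] = args.batch_size
    if args.lr is not None:
        cfg["training"]["optimizer"]["lr"] = args.lr
    if args.epochs is not None:
        cfg["training"]["num_epochs"] = args.epochs
    return cfg
```

Flags left unset default to `None`, so only values given explicitly on the command line replace those from the config file.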

Benchmarking

# Benchmark all architectures
python scripts/benchmark_models.py --config configs/base_config.yaml --epochs 100

# Benchmark specific models
python scripts/benchmark_models.py --config configs/base_config.yaml --models UNet SwinUNETR SegResNet

Hyperparameter Optimization

# Run Optuna optimization
python training/hyperopt.py --config configs/base_config.yaml --trials 50

Inference

# Single model inference
python run_pipeline.py --config configs/base_config.yaml --checkpoint checkpoints/best.ckpt --input data/test/ --output predictions/

# With test-time augmentation
python run_pipeline.py --config configs/base_config.yaml --checkpoint checkpoints/best.ckpt --input data/test/ --output predictions/ --tta
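
To show what test-time augmentation does conceptually, here is a minimal NumPy sketch that averages predictions over all flip combinations of the three spatial axes. Which augmentations the `--tta` flag actually applies is not stated here, so flips are an assumption; `predict_fn` is a hypothetical callable that maps a (C, D, H, W) array to a prediction with the same spatial layout.

```python
import itertools
import numpy as np

def predict_with_flip_tta(predict_fn, volume, spatial_axes=(-3, -2, -1)):
    """Average predictions over every combination of spatial-axis flips.

    predict_fn is a hypothetical model wrapper; its output must keep the
    same trailing spatial axes as the input.
    """
    total = None
    count = 0
    for r in range(len(spatial_axes) + 1):
        for axes in itertools.combinations(spatial_axes, r):
            # Flip the input, predict, then undo the flip on the output.
            flipped = np.flip(volume, axis=axes) if axes else volume
            pred = predict_fn(flipped)
            pred = np.flip(pred, axis=axes) if axes else pred
            total = pred.astype(np.float64) if total is None else total + pred
            count += 1
    return total / count
```

With three spatial axes this runs eight forward passes per volume, so inference time grows accordingly.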

Project Structure

mslesseg_framework/
├── configs/
│   └── base_config.yaml          # Main configuration file
├── data/
│   ├── preprocessing/
│   │   └── pipeline.py           # MRI preprocessing (N4, normalization, resampling)
│   └── splits/
│       └── patient_split.py      # Patient-wise stratified splitting
├── models/
│   ├── architectures.py          # 7 segmentation models
│   └── losses.py                 # 8 loss functions
├── training/
│   ├── trainer.py                # PyTorch Lightning training module
│   ├── train.py                  # Training script
│   └── hyperopt.py               # Optuna hyperparameter optimization
├── evaluation/
│   ├── metrics.py                # Comprehensive metrics (voxel + lesion + boundary)
│   └── explainability.py         # Grad-CAM, uncertainty, attention
├── inference/
│   └── inference.py              # Sliding window inference, ensemble, benchmarking
├── scripts/
│   └── benchmark_models.py       # Architecture comparison script
├── utils/
│   └── config.py                 # Configuration management
├── run_pipeline.py               # Main orchestration script
├── requirements.txt
└── README.md
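
The tree mentions patient-wise splitting, meaning that all scans of one patient stay in the same split so that no patient leaks between training and validation. A minimal sketch of that idea follows. The scan naming scheme, the `patient_of` callback, and the function name are hypothetical, and the stratification step (whose criterion is not stated here) is omitted.

```python
import random
from collections import defaultdict

def patient_wise_split(scan_ids, patient_of, val_fraction=0.2, seed=42):
    """Split scans into train/val so that no patient appears in both.

    patient_of is a hypothetical callable mapping a scan ID to a patient ID.
    """
    # Group scans (e.g. several timepoints) under their patient.
    by_patient = defaultdict(list)
    for scan in scan_ids:
        by_patient[patient_of(scan)].append(scan)

    # Shuffle patients, not scans, with a fixed seed for reproducibility.
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)

    n_val = max(1, round(len(patients) * val_fraction))
    val_patients = set(patients[:n_val])

    train = [s for p in patients if p not in val_patients for s in by_patient[p]]
    val = [s for p in patients if p in val_patients for s in by_patient[p]]
    return train, val
```

For example, with scan IDs of the assumed form `P3_T2` (patient 3, timepoint 2), `patient_of` could be `lambda s: s.split("_")[0]`.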

Architecture Comparison

Model          Parameters  Type         Key Feature
SwinUNETR      15.8M       Transformer  Hierarchical shifted-window attention
SegResNet      18.8M       CNN          Residual encoder-decoder
UNet           19.2M       CNN          Standard with residual units
AttentionUNet  23.6M       CNN          Attention gates on skips
UNetPlusPlus   34.4M       CNN          Dense nested skips
VNet           45.7M       CNN          V-shaped residual
UNETR          127.4M      Transformer  Pure transformer encoder

Configuration

Key parameters in configs/base_config.yaml:

model:
  name: "SwinUNETR"              # Architecture selection
  feature_size: 48               # For SwinUNETR

dataset:
  patch_size: [128, 128, 128]    # Training patch size
  modalities: ["FLAIR", "T1", "T2"]

training:
  batch_size: 2
  num_epochs: 800
  optimizer:
    name: "AdamW"
    lr: 0.0001
  loss:
    name: "DiceCELoss"
  scheduler:
    name: "cosine_warmup"

Evaluation Metrics

Voxel-Level

  • Dice Similarity Coefficient (DSC)
  • Intersection over Union (IoU)
  • Precision, Recall, F1
  • Sensitivity, Specificity
  • False Positive/Discovery Rate

Lesion-Level

  • Lesion Precision, Recall, F1
  • Volume Similarity
  • Small Lesion Detection Rate
  • Number of True/False Positives
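
Lesion-level metrics treat each connected component as one lesion. The sketch below uses `scipy.ndimage.label` (default face connectivity) and counts a lesion as matched when it overlaps the other mask by at least one voxel. That matching rule, the function name, and the returned keys are assumptions; the criterion used by the framework is not stated here.

```python
import numpy as np
from scipy import ndimage

def lesion_level_metrics(pred, target):
    """Lesion-wise detection metrics for two binary masks of equal shape.

    Assumes an any-overlap matching rule (illustrative, not the
    framework's documented criterion).
    """
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    pred_labels, n_pred = ndimage.label(pred)
    target_labels, n_target = ndimage.label(target)

    overlap = pred & target
    # Reference lesions touched by any predicted voxel.
    detected = len(np.unique(target_labels[overlap]))
    # Predicted lesions that touch any reference voxel.
    tp_pred = len(np.unique(pred_labels[overlap]))

    recall = detected / n_target if n_target else float("nan")
    precision = tp_pred / n_pred if n_pred else float("nan")
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else float("nan")

    vp, vt = int(pred.sum()), int(target.sum())
    volume_similarity = 1.0 - abs(vp - vt) / (vp + vt) if (vp + vt) else float("nan")

    return {
        "lesion_recall": recall,
        "lesion_precision": precision,
        "lesion_f1": f1,
        "true_positives": tp_pred,
        "false_positives": n_pred - tp_pred,
        "volume_similarity": volume_similarity,
    }
```

A small-lesion detection rate could be obtained the same way by restricting the reference components to those below a chosen volume threshold before computing recall.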

Boundary & Spatial

  • Hausdorff Distance (95th percentile)
  • Boundary F1 Score
  • Surface Dice
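
As a sketch of the 95th-percentile Hausdorff distance, the code below extracts surface voxels by erosion and measures surface-to-surface distances with a Euclidean distance transform. Definitions of HD95 vary between toolkits; this version takes the maximum of the two directed 95th percentiles, which is an assumption about the convention, and the function names and the `spacing` parameter are illustrative.

```python
import numpy as np
from scipy import ndimage

def surface_voxels(mask):
    """Voxels of a binary mask that are removed by a one-step erosion."""
    mask = np.asarray(mask, dtype=bool)
    return mask & ~ndimage.binary_erosion(mask)

def hausdorff95(pred, target, spacing=(1.0, 1.0, 1.0)):
    """HD95 between two 3D binary masks, in the units of `spacing`."""
    pred_surf = surface_voxels(pred)
    target_surf = surface_voxels(target)
    if not pred_surf.any() or not target_surf.any():
        return float("nan")  # undefined when either mask is empty

    # Distance from every voxel to the nearest surface voxel of each mask.
    dist_to_target = ndimage.distance_transform_edt(~target_surf, sampling=spacing)
    dist_to_pred = ndimage.distance_transform_edt(~pred_surf, sampling=spacing)

    # Directed distances, sampled on the opposite surface.
    d_pred_to_target = dist_to_target[pred_surf]
    d_target_to_pred = dist_to_pred[target_surf]

    return float(max(np.percentile(d_pred_to_target, 95),
                     np.percentile(d_target_to_pred, 95)))
```

Passing the voxel spacing of the scan as `spacing` yields distances in millimetres rather than voxels.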

Citation

If you use this framework, please cite:

@article{guarnera2025mslesseg,
  title={MSLesSeg: baseline and benchmarking of a new Multiple Sclerosis Lesion Segmentation dataset},
  author={Guarnera, Francesco and others},
  journal={Scientific Data},
  year={2025}
}

License

This framework is provided for research purposes only. The MSLesSeg dataset is available under a CC-BY license.

Acknowledgments

  • MSLesSeg dataset: Guarnera et al., Scientific Data 2025
  • MONAI: Medical Open Network for AI
  • PyTorch Lightning: Scalable PyTorch training

Generated by ML Intern

This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
