MSLesSeg Framework: MS Lesion Segmentation
A comprehensive, research-grade deep learning framework for automated Multiple Sclerosis (MS) lesion segmentation from multimodal MRI scans using the MSLesSeg dataset.
Features
- 7 State-of-the-Art Architectures: UNet, AttentionUNet, UNet++, SwinUNETR, UNETR, SegResNet, VNet
- Multimodal MRI Support: FLAIR, T1, T2 channel input
- Advanced Loss Functions: Dice, Dice+CE, Dice+Focal, Tversky, FocalTversky, Boundary, Lovasz
- Comprehensive Evaluation: Voxel-level, lesion-level, boundary, and spatial metrics
- Explainability: Grad-CAM, saliency maps, uncertainty estimation, attention visualization
- Training Infrastructure: PyTorch Lightning, mixed precision, gradient checkpointing
- Hyperparameter Optimization: Optuna integration with pruning
- Inference Optimization: Sliding window, test-time augmentation, ensemble methods
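The Tversky-family losses named above trade off false positives against false negatives. The NumPy sketch below shows the idea on foreground probabilities for a single binary lesion channel. It is an illustration under assumed defaults (alpha = 0.3, beta = 0.7, gamma = 0.75) and hypothetical function names, not the code in models/losses.py. With alpha = beta = 0.5 the Tversky index reduces to the soft Dice coefficient.

```python
import numpy as np


def tversky_index(probs, target, alpha=0.3, beta=0.7, eps=1e-6):
    """Soft Tversky index between foreground probabilities and a binary mask.

    alpha weights false positives and beta weights false negatives
    (assumed defaults); alpha = beta = 0.5 gives the soft Dice coefficient.
    """
    p = np.asarray(probs, dtype=np.float64).ravel()
    g = np.asarray(target, dtype=np.float64).ravel()
    tp = np.sum(p * g)
    fp = np.sum(p * (1.0 - g))
    fn = np.sum((1.0 - p) * g)
    return float((tp + eps) / (tp + alpha * fp + beta * fn + eps))


def tversky_loss(probs, target, alpha=0.3, beta=0.7):
    return 1.0 - tversky_index(probs, target, alpha, beta)


def focal_tversky_loss(probs, target, alpha=0.3, beta=0.7, gamma=0.75):
    # gamma is an assumed default; max() guards against tiny negative bases
    return max(0.0, 1.0 - tversky_index(probs, target, alpha, beta)) ** gamma


def dice_loss(probs, target):
    # Dice is the alpha = beta = 0.5 special case of the Tversky index
    return tversky_loss(probs, target, alpha=0.5, beta=0.5)
```

Setting beta above alpha penalizes missed lesion voxels more than spurious ones, which is one motivation for using Tversky-style losses on small, sparse targets such as MS lesions.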
Quick Start
Installation
```bash
# After cloning the repository, enter its directory
cd mslesseg_framework
# Install dependencies
pip install -r requirements.txt
```
Data Preparation
- Download the MSLesSeg dataset from Figshare.
- Place the data in the data/MSLesSeg/ directory.
- Run preprocessing:

```bash
python run_pipeline.py --config configs/base_config.yaml --preprocess-only
```
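As a rough illustration of what the preprocessing step produces, the sketch below stacks the three modalities from the config into one channel-first array and z-scores each modality over brain voxels only. It assumes the volumes are already co-registered, skull-stripped, and loaded as NumPy arrays (for example with nibabel); the N4 bias-field correction and resampling listed for data/preprocessing/pipeline.py are omitted, and the function names are hypothetical.

```python
import numpy as np

MODALITIES = ("FLAIR", "T1", "T2")  # channel order taken from the config


def zscore_in_mask(volume, mask):
    """Z-score a volume using statistics from brain voxels only; background stays 0."""
    values = volume[mask].astype(np.float64)
    std = values.std()
    out = np.zeros(volume.shape, dtype=np.float32)
    out[mask] = (values - values.mean()) / (std if std > 0 else 1.0)
    return out


def stack_modalities(volumes, mask=None):
    """Stack co-registered FLAIR/T1/T2 arrays into one (C, D, H, W) network input."""
    shapes = {volumes[m].shape for m in MODALITIES}
    if len(shapes) != 1:
        raise ValueError(f"modalities must share one voxel grid, got {shapes}")
    if mask is None:
        # Assumption: skull-stripped inputs, so any non-zero voxel counts as brain
        mask = np.any([volumes[m] != 0 for m in MODALITIES], axis=0)
    return np.stack([zscore_in_mask(volumes[m], mask) for m in MODALITIES])
```

Computing the statistics inside the brain mask keeps the large zero-valued background from skewing the mean and standard deviation.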
Training
```bash
# Train with default config (SwinUNETR)
python run_pipeline.py --config configs/base_config.yaml
# Train specific model
python run_pipeline.py --config configs/base_config.yaml --model UNet
# Train with custom hyperparameters
python run_pipeline.py --config configs/base_config.yaml --model SwinUNETR --batch-size 1 --lr 1e-4 --epochs 400
```
Benchmarking
```bash
# Benchmark all architectures
python scripts/benchmark_models.py --config configs/base_config.yaml --epochs 100
# Benchmark specific models
python scripts/benchmark_models.py --config configs/base_config.yaml --models UNet SwinUNETR SegResNet
```
Hyperparameter Optimization
```bash
# Run Optuna optimization
python training/hyperopt.py --config configs/base_config.yaml --trials 50
```
Inference
```bash
# Single model inference
python run_pipeline.py --config configs/base_config.yaml --checkpoint checkpoints/best.ckpt --input data/test/ --output predictions/
# With test-time augmentation
python run_pipeline.py --config configs/base_config.yaml --checkpoint checkpoints/best.ckpt --input data/test/ --output predictions/ --tta
```
Project Structure
```text
mslesseg_framework/
├── configs/
│   └── base_config.yaml        # Main configuration file
├── data/
│   ├── preprocessing/
│   │   └── pipeline.py         # MRI preprocessing (N4, normalization, resampling)
│   └── splits/
│       └── patient_split.py    # Patient-wise stratified splitting
├── models/
│   ├── architectures.py        # 7 segmentation models
│   └── losses.py               # 8 loss functions
├── training/
│   ├── trainer.py              # PyTorch Lightning training module
│   ├── train.py                # Training script
│   └── hyperopt.py             # Optuna hyperparameter optimization
├── evaluation/
│   ├── metrics.py              # Comprehensive metrics (voxel + lesion + boundary)
│   └── explainability.py       # Grad-CAM, uncertainty, attention
├── inference/
│   └── inference.py            # Sliding window inference, ensemble, benchmarking
├── scripts/
│   └── benchmark_models.py     # Architecture comparison script
├── utils/
│   └── config.py               # Configuration management
├── run_pipeline.py             # Main orchestration script
├── requirements.txt
└── README.md
```
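Splitting by patient rather than by scan keeps every scan of a given patient (for example, several timepoints) on one side of the split, so no subject leaks from training into evaluation. A minimal sketch of a patient-wise stratified split follows. The stratum labels (here, a hypothetical lesion-load category per patient), the 20% test fraction, and the seed are assumptions, not the logic in data/splits/patient_split.py.

```python
import random
from collections import defaultdict


def patient_split(patient_strata, test_fraction=0.2, seed=42):
    """Split patient IDs into train/test so that no patient appears in both.

    patient_strata maps patient ID -> stratum label (hypothetical, e.g. a
    lesion-load bin). Each stratum is split separately so both sets keep a
    similar mix of strata.
    """
    by_stratum = defaultdict(list)
    for pid, stratum in patient_strata.items():
        by_stratum[stratum].append(pid)
    rng = random.Random(seed)
    train, test = [], []
    for stratum in sorted(by_stratum):
        pids = sorted(by_stratum[stratum])  # sort first so the seed fully determines the split
        rng.shuffle(pids)
        n_test = round(len(pids) * test_fraction)
        test.extend(pids[:n_test])
        train.extend(pids[n_test:])
    return sorted(train), sorted(test)
```

All scans are then assigned through their patient ID, so the split is decided once per subject rather than once per file.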
Architecture Comparison
| Model | Parameters | Type | Key Feature |
|---|---|---|---|
| SwinUNETR | 15.8M | Transformer | Hierarchical shifted-window attention |
| SegResNet | 18.8M | CNN | Residual encoder-decoder |
| UNet | 19.2M | CNN | Standard with residual units |
| AttentionUNet | 23.6M | CNN | Attention gates on skips |
| UNetPlusPlus | 34.4M | CNN | Dense nested skips |
| VNet | 45.7M | CNN | V-shaped residual |
| UNETR | 127.4M | Transformer | Pure transformer encoder |
Configuration
Key parameters in configs/base_config.yaml:
```yaml
model:
  name: "SwinUNETR"              # Architecture selection
  feature_size: 48               # For SwinUNETR
dataset:
  patch_size: [128, 128, 128]    # Training patch size
  modalities: ["FLAIR", "T1", "T2"]
training:
  batch_size: 2
  num_epochs: 800
  optimizer:
    name: "AdamW"
    lr: 0.0001
  loss:
    name: "DiceCELoss"
  scheduler:
    name: "cosine_warmup"
```
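The cosine_warmup scheduler name suggests a linear warmup followed by cosine decay. The function below sketches that schedule per optimizer step. base_lr matches the lr in the config, while warmup_steps, total_steps, and min_lr are assumed values; the framework's own scheduler may step per epoch or use different constants.

```python
import math


def cosine_warmup_lr(step, base_lr=1e-4, warmup_steps=500, total_steps=10_000, min_lr=0.0):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay.

    warmup_steps, total_steps and min_lr are assumed values for illustration.
    """
    if step < warmup_steps:
        # Ramp linearly so the last warmup step reaches base_lr
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, one way to apply it is torch.optim.lr_scheduler.LambdaLR with lr_lambda=lambda s: cosine_warmup_lr(s) / 1e-4, since LambdaLR multiplies the optimizer's initial learning rate by the returned factor.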
Evaluation Metrics
Voxel-Level
- Dice Similarity Coefficient (DSC)
- Intersection over Union (IoU)
- Precision, Recall, F1
- Sensitivity, Specificity
- False Positive Rate, False Discovery Rate
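The voxel-level scores above all derive from the four confusion counts between the predicted and reference lesion masks. The sketch below computes them for binary NumPy masks. Reporting NaN when a denominator is empty is an assumed convention; the framework's metrics.py may handle empty masks differently. Note that voxel-wise F1 equals Dice, and recall equals sensitivity.

```python
import numpy as np


def voxel_metrics(pred, gt):
    """Voxel-level overlap metrics for binary masks (True = lesion)."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    tp = np.sum(pred & gt)
    fp = np.sum(pred & ~gt)
    fn = np.sum(~pred & gt)
    tn = np.sum(~pred & ~gt)

    def ratio(num, den):
        # Assumed convention: a metric with an empty denominator is NaN
        return float(num / den) if den > 0 else float("nan")

    return {
        "dice": ratio(2 * tp, 2 * tp + fp + fn),  # identical to voxel-wise F1
        "iou": ratio(tp, tp + fp + fn),
        "precision": ratio(tp, tp + fp),
        "recall": ratio(tp, tp + fn),  # = sensitivity
        "specificity": ratio(tn, tn + fp),
        "fpr": ratio(fp, fp + tn),
        "fdr": ratio(fp, fp + tp),
    }
```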
Lesion-Level
- Lesion Precision, Recall, F1
- Volume Similarity
- Small Lesion Detection Rate
- Number of true-positive and false-positive lesions
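Lesion-level metrics treat each connected component as one lesion. In this sketch a predicted lesion is a true positive if it overlaps any ground-truth voxel, and a ground-truth lesion counts as detected if any predicted voxel falls inside it. The 6-connectivity, the any-overlap matching rule, and the 10-voxel small-lesion threshold are assumptions rather than the framework's documented settings; volume similarity is omitted.

```python
from collections import deque

import numpy as np

_OFFSETS = ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1))


def label_components(mask):
    """Label 6-connected components of a 3D boolean mask; returns (labels, count)."""
    mask = np.asarray(mask, dtype=bool)
    labels = np.zeros(mask.shape, dtype=np.int32)
    count = 0
    for seed in zip(*np.nonzero(mask)):
        if labels[seed]:
            continue
        count += 1
        labels[seed] = count
        queue = deque([seed])  # breadth-first flood fill from this seed
        while queue:
            z, y, x = queue.popleft()
            for dz, dy, dx in _OFFSETS:
                nb = (z + dz, y + dy, x + dx)
                inside = all(0 <= c < s for c, s in zip(nb, mask.shape))
                if inside and mask[nb] and not labels[nb]:
                    labels[nb] = count
                    queue.append(nb)
    return labels, count


def lesion_metrics(pred, gt, small_max_voxels=10):
    """Lesion-wise detection scores; any voxel overlap counts as a hit."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    pred_lab, n_pred = label_components(pred)
    gt_lab, n_gt = label_components(gt)
    # A predicted lesion is a true positive if it touches any ground-truth voxel
    n_tp = sum(bool(gt[pred_lab == i].any()) for i in range(1, n_pred + 1))
    # A ground-truth lesion is detected if any predicted voxel falls inside it
    hits = [bool(pred[gt_lab == j].any()) for j in range(1, n_gt + 1)]
    sizes = [int(np.sum(gt_lab == j)) for j in range(1, n_gt + 1)]
    small = [h for h, s in zip(hits, sizes) if s <= small_max_voxels]
    precision = n_tp / n_pred if n_pred else float("nan")
    recall = sum(hits) / n_gt if n_gt else float("nan")
    if n_pred and n_gt:
        f1 = 0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
    else:
        f1 = float("nan")
    return {
        "n_tp": n_tp,
        "n_fp": n_pred - n_tp,
        "n_fn": n_gt - sum(hits),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "small_lesion_detection_rate": sum(small) / len(small) if small else float("nan"),
    }
```

scipy.ndimage.label offers faster connected-component labelling (its default 3D structure is also 6-connected); the breadth-first version here only keeps the example dependency-free.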
Boundary & Spatial
- Hausdorff Distance (95th percentile)
- Boundary F1 Score
- Surface Dice
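HD95 replaces the maximum of the surface-to-surface distances with their 95th percentile, which makes it less sensitive to a few outlying voxels. The brute-force sketch below extracts surface voxels with a 6-neighborhood test and takes the larger of the two directed 95th-percentile distances, which is one common convention. The spacing argument (voxel size, e.g. in mm) and the error raised on empty masks are assumptions, and the pairwise distance matrix makes it suitable only for small masks.

```python
import numpy as np


def surface_points(mask, spacing=(1.0, 1.0, 1.0)):
    """Physical coordinates of mask voxels with at least one background 6-neighbor."""
    m = np.pad(np.asarray(mask, dtype=bool), 1)
    core = m[1:-1, 1:-1, 1:-1]
    interior = core.copy()
    for axis in range(3):
        for shift in (-1, 1):
            # A voxel is interior only if all six face neighbors are in the mask
            interior &= np.roll(m, shift, axis=axis)[1:-1, 1:-1, 1:-1]
    return np.argwhere(core & ~interior) * np.asarray(spacing, dtype=float)


def hd95(pred, gt, spacing=(1.0, 1.0, 1.0)):
    """95th-percentile symmetric Hausdorff distance between two binary 3D masks."""
    a = surface_points(pred, spacing)
    b = surface_points(gt, spacing)
    if len(a) == 0 or len(b) == 0:
        raise ValueError("HD95 is undefined when either mask is empty")
    # Brute-force pairwise distances: O(|A| * |B|) memory, fine for small masks
    d = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    return float(max(np.percentile(d.min(axis=1), 95), np.percentile(d.min(axis=0), 95)))
```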
Citation
If you use this framework, please cite:
```bibtex
@article{guarnera2025mslesseg,
  title={MSLesSeg: baseline and benchmarking of a new Multiple Sclerosis Lesion Segmentation dataset},
  author={Guarnera, Maria and others},
  journal={Scientific Data},
  year={2025}
}
```
License
This framework is provided for research purposes. The MSLesSeg dataset is available under a CC-BY license.
Acknowledgments
- MSLesSeg dataset: Guarnera et al., Scientific Data 2025
- MONAI: Medical Open Network for AI
- PyTorch Lightning: Scalable PyTorch training
Generated by ML Intern
This model repository was generated by ML Intern, an agent for machine learning research and development on the Hugging Face Hub.
- Try ML Intern: https://smolagents-ml-intern.hf.space
- Source code: https://github.com/huggingface/ml-intern
Usage
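This repository contains the framework's source code and configuration files rather than a Transformers checkpoint, so it cannot be loaded with AutoModelForCausalLM or any other AutoModel class. Download the files and run the scripts described under Quick Start instead:

```python
from huggingface_hub import snapshot_download

# Download the repository files (code, configs, requirements) to the local cache
local_dir = snapshot_download(repo_id="Joel1810/mslesseg-framework")
print(local_dir)
```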