Neuro-AI: AI-Driven MS Lesion Analysis Framework

Ventricles & WMH Segmentation:

Pre-trained models for ventricles and white matter hyperintensity (WMH) segmentation with explicit distinction between normal periventricular changes (normal WMH) and pathological lesions (abnormal WMH).

Model Description

This repository is intended to host six pre-trained deep learning models, one per architecture listed below, for automated, simultaneous ventricle and WMH segmentation from FLAIR MRI images. The models implement a novel four-class approach that distinguishes between:

  • Class 0: Background
  • Class 1: Ventricular system
  • Class 2: Normal WMH (aging-related periventricular changes)
  • Class 3: Abnormal WMH (pathologically significant lesions)

Model Architectures

Architecture                                             | Mean DSC      | Mean IoU      | Mean HD95   | Mean Precision | Mean Recall
Baseline U-Net (WCE)                                     | 0.714 ± 0.018 | 0.601 ± 0.020 | 6.50 ± 0.46 | 0.616 ± 0.020  | 0.937 ± 0.002
Baseline Pix2Pix (WCE)                                   | 0.823 ± 0.011 | 0.721 ± 0.014 | 5.31 ± 0.20 | 0.805 ± 0.005  | 0.848 ± 0.010
Baseline Pix2Pix (UFL)                                   | 0.817 ± 0.010 | 0.714 ± 0.012 | 5.50 ± 0.35 | 0.791 ± 0.021  | 0.850 ± 0.019
Pix2Pix + Attention Discriminator                        | 0.824 ± 0.008 | 0.723 ± 0.011 | 5.23 ± 0.31 | 0.807 ± 0.014  | 0.843 ± 0.009
Pix2Pix + Adaptive Hybrid Loss                           | 0.844 ± 0.002 | 0.750 ± 0.003 | 4.81 ± 0.05 | 0.845 ± 0.009  | 0.843 ± 0.010
Pix2Pix + Attention Discriminator + Adaptive Hybrid Loss | 0.852 ± 0.004 | 0.760 ± 0.006 | 4.87 ± 0.13 | 0.856 ± 0.006  | 0.850 ± 0.006

Recommended: Pix2Pix + Attention Discriminator + Adaptive Hybrid Loss (V5), which achieves the highest mean DSC, IoU, and precision in the table above.
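
The overlap metrics reported in the table can be computed per class from integer label maps. The following is a minimal NumPy sketch of DSC, IoU, precision, and recall for a single class. The function name `overlap_metrics` and the toy arrays are assumptions made for illustration; this is not the authors' evaluation code, and HD95 is left out.

```python
import numpy as np

def overlap_metrics(y_true, y_pred, class_id, eps=1e-7):
    """Return DSC, IoU, precision and recall for one class of two label maps."""
    t = (y_true == class_id)
    p = (y_pred == class_id)
    tp = np.logical_and(t, p).sum()    # pixels correctly assigned to the class
    fp = np.logical_and(~t, p).sum()   # pixels wrongly assigned to the class
    fn = np.logical_and(t, ~p).sum()   # pixels of the class that were missed
    return {
        "dsc": 2 * tp / (2 * tp + fp + fn + eps),
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
    }

# Toy 2x3 label maps (hypothetical data); class 3 = abnormal WMH
y_true = np.array([[0, 3, 3], [1, 2, 0]])
y_pred = np.array([[0, 3, 0], [1, 2, 3]])
metrics = overlap_metrics(y_true, y_pred, class_id=3)
```

Averaging these values over classes and folds would give figures comparable in form to the table, although the exact averaging scheme used for the reported numbers is not stated here.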

Repository Structure

results_fold_2_var_5_bet_zscore/
└── models/standard_4class/fold_2
    └── best_dice_generator.h5      # 4-Class: Background, Ventricles, Normal, Abnormal

Quick Start

Installation

pip install huggingface_hub tensorflow numpy nibabel

Download Models

from huggingface_hub import hf_hub_download

# Download best performing model (V5)
model_path = hf_hub_download(
    repo_id="Bawil/neuro-ai",
    filename="results_fold_2_var_5_bet_zscore/models/standard_4class/fold_2/best_dice_generator.h5"
)

# Load model
from tensorflow.keras.models import load_model
model = load_model(model_path)

Inference Example

import numpy as np
from tensorflow.keras.models import load_model

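# Load pre-trained model (model_path comes from the download step above)
model = load_model(model_path)

# Prepare input (256x256 grayscale FLAIR MRI, normalized)
# input_image shape: (batch_size, 256, 256, 1)
# preprocess_flair is a placeholder for your own preprocessing function;
# it is not defined in this snippet.
input_image = preprocess_flair(your_flair_image)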

# Run inference
predictions = model.predict(input_image)

# Get class predictions
predicted_classes = np.argmax(predictions, axis=-1)
# 0: Background
# 1: Ventricles
# 2: Normal WMH (periventricular)
# 3: Abnormal WMH (pathological)

# Extract pathological lesions only (class 3: abnormal WMH)
abnormal_mask = (predicted_classes == 3).astype(np.uint8)
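
The inference example calls `preprocess_flair` without defining it. One possible shape for such a helper is sketched below. It assumes z-score intensity normalization over non-zero pixels (suggested only by the `bet_zscore` part of the model folder name) and a slice that is already 256×256; the preprocessing actually used in training is not documented in this card, so treat every step here as an assumption.

```python
import numpy as np

def preprocess_flair(slice_2d):
    """Hypothetical preprocessing: z-score a 256x256 FLAIR slice over its
    non-zero pixels, then add batch and channel axes."""
    img = np.asarray(slice_2d, dtype=np.float32)
    if img.shape != (256, 256):
        raise ValueError(f"expected a 256x256 slice, got {img.shape}")
    brain = img != 0  # assumes background is exactly zero after brain extraction
    out = np.zeros_like(img)
    if brain.any():
        mean, std = img[brain].mean(), img[brain].std()
        out[brain] = (img[brain] - mean) / (std + 1e-8)
    return out[np.newaxis, ..., np.newaxis]  # shape (1, 256, 256, 1)

# Synthetic slice (hypothetical data) with a 50x50 non-zero region
example = np.zeros((256, 256))
example[100:150, 100:150] = np.arange(2500).reshape(50, 50) + 1
input_image = preprocess_flair(example)
```

Slices of a different size would need resampling to 256×256 first, which is deliberately left out of this sketch.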

Training Data

Dataset Composition

  • Local Dataset: 300 MS patients (6,000 FLAIR MRI slices)

    • Demographics: 78 males, 222 females
    • Age range: 18-68 years
    • Scanner: 1.5-Tesla TOSHIBA Vantage
  • Public Dataset: MSSEG2016 (15 patients, 750 FLAIR slices)

Annotations

  • Expert annotations by board-certified neuroradiologists (20+ years experience)
  • Four-class labeling: Background, Ventricles, Normal WMH, Abnormal WMH
  • Approved by Ethics Committee (IR.TBZMED.REC.1402.902)

Data Split

  • Training: 70% patients (local)
  • Validation: 10% patients (local)
  • Testing: 20% patients (local) + 40% patients (public)
  • Strategy: Patient-level stratified split (no slice-level leakage)
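
A patient-level split assigns every slice of a patient to the same subset, which is what prevents slice-level leakage. The sketch below illustrates the idea with NumPy for the 70/10/20 proportions. The function `split_patients`, the seed, and the 20-slices-per-patient layout are assumptions for illustration, and the stratification step mentioned above is omitted.

```python
import numpy as np

def split_patients(patient_ids, fractions=(0.7, 0.1, 0.2), seed=0):
    """Shuffle unique patient IDs and split them into train/val/test sets."""
    ids = np.unique(np.asarray(patient_ids))
    rng = np.random.default_rng(seed)
    rng.shuffle(ids)
    n_train = int(round(fractions[0] * len(ids)))
    n_val = int(round(fractions[1] * len(ids)))
    train = set(ids[:n_train].tolist())
    val = set(ids[n_train:n_train + n_val].tolist())
    test = set(ids[n_train + n_val:].tolist())  # remainder goes to test
    return train, val, test

# Hypothetical layout: 300 patients with 20 slices each, one ID per slice
slice_patient_ids = np.repeat(np.arange(300), 20)
train_ids, val_ids, test_ids = split_patients(slice_patient_ids)

# Boolean mask selecting every slice that belongs to a training patient
train_slices = np.isin(slice_patient_ids, list(train_ids))
```

Because the split is made on patient IDs before slices are selected, no patient can contribute slices to more than one subset.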

Model Training

Configuration

  • Framework: TensorFlow 2.11, Keras
  • Optimizer: Adam (learning rate: 0.0002)
  • Loss Functions:
    • Option 1: Weighted categorical cross-entropy
    • Option 2: Unified Focal Dice
  • Epochs: 60 (with early stopping)
  • Batch Size: 4
  • Input Size: 256×256×1
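
To make the first loss option concrete, here is a generic NumPy formulation of weighted categorical cross-entropy for one-hot labels and softmax outputs. It is a sketch of the general technique only: the class weights are hypothetical, and the authors' exact weighting and implementation are not given in this card.

```python
import numpy as np

def weighted_categorical_crossentropy(y_true, y_prob, class_weights, eps=1e-7):
    """Mean per-pixel cross-entropy with one weight per class.

    y_true: one-hot labels, shape (..., n_classes)
    y_prob: softmax probabilities, same shape as y_true
    """
    y_prob = np.clip(y_prob, eps, 1.0)  # avoid log(0)
    w = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(w * y_true * np.log(y_prob), axis=-1)
    return per_pixel.mean()

# Hypothetical weights that up-weight the rarer foreground classes
weights = [0.1, 1.0, 2.0, 2.0]
y_true = np.eye(4)[[0, 3]]  # two pixels: background and abnormal WMH
y_prob = np.array([[0.70, 0.10, 0.10, 0.10],
                   [0.25, 0.25, 0.25, 0.25]])
loss = weighted_categorical_crossentropy(y_true, y_prob, weights)
```

With all weights set to 1 the function reduces to ordinary categorical cross-entropy, which is a quick way to sanity-check a weighted variant.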

Hardware

  • GPU: NVIDIA RTX 3060 (12GB VRAM)
  • Training Time: 3-4 hours per model (5-fold CV)
  • Inference Time: ~35-40ms per image

Model Performance

Use Cases

Clinical Applications

  • MS Lesion Quantification: Accurate measurement of disease burden
  • Differential Diagnosis: Distinguish pathological from normal aging
  • Longitudinal Monitoring: Track disease progression over time
  • Treatment Response: Evaluate therapeutic efficacy
  • Radiological Reporting: Reduce false positive alerts
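
For lesion quantification, disease burden is commonly summarized as the volume of the abnormal WMH class. The sketch below converts a predicted label volume to millilitres. The function `lesion_volume_ml` and the voxel size are assumptions for illustration, not part of the released code.

```python
import numpy as np

def lesion_volume_ml(label_volume, voxel_size_mm, class_id=3):
    """Volume of one class in millilitres, given voxel dimensions in mm."""
    n_voxels = int(np.count_nonzero(label_volume == class_id))
    voxel_mm3 = float(np.prod(voxel_size_mm))
    return n_voxels * voxel_mm3 / 1000.0  # 1 mL = 1000 mm^3

# Hypothetical label volume containing 200 voxels of class 3
labels = np.zeros((256, 256, 20), dtype=np.uint8)
labels[100:110, 100:110, 5:7] = 3
volume_ml = lesion_volume_ml(labels, voxel_size_mm=(0.9, 0.9, 5.0))
```

When the scan is loaded with nibabel, the voxel dimensions can be read from `img.header.get_zooms()`. They must describe the grid on which the segmentation was produced, so any resampling to 256×256 has to be accounted for.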

Research Applications

  • Baseline Comparisons: Standardized evaluation framework
  • Method Development: Foundation for advanced segmentation approaches
  • Multi-center Studies: Protocol for broader validation
  • Reproducible Research: Complete implementation available

Limitations

  • Single Modality: Trained on FLAIR MRI only
  • Scanner Specificity: Primarily 1.5T TOSHIBA data
  • Disease Focus: Optimized for MS patients
  • 2D Segmentation: Slice-by-slice processing (no 3D context)
  • Resolution: Fixed 256×256 input size
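
Slice-by-slice processing means a 3D scan is segmented by running the 2D model on each slice and stacking the results. The sketch below shows this under the assumption of a preprocessed volume of shape (256, 256, n_slices). `segment_volume` and `dummy_predict` are hypothetical names, and the dummy predictor stands in for `model.predict` so the example runs without the weights.

```python
import numpy as np

def segment_volume(volume, predict_fn):
    """Apply a 2D segmentation function to each slice of a volume.

    volume: array of shape (256, 256, n_slices), already normalized
    predict_fn: callable mapping (n, 256, 256, 1) -> (n, 256, 256, 4),
                for example model.predict
    """
    slices = np.moveaxis(volume, -1, 0)[..., np.newaxis]  # (n, 256, 256, 1)
    probs = predict_fn(slices.astype(np.float32))
    labels = np.argmax(probs, axis=-1).astype(np.uint8)   # (n, 256, 256)
    return np.moveaxis(labels, 0, -1)                     # (256, 256, n)

def dummy_predict(batch):
    """Stand-in predictor: bright pixels become class 3, the rest class 0."""
    out = np.zeros(batch.shape[:3] + (4,), dtype=np.float32)
    out[..., 0] = 1.0
    bright = batch[..., 0] > 0.5
    out[bright, 0] = 0.0
    out[bright, 3] = 1.0
    return out

volume = np.zeros((256, 256, 5), dtype=np.float32)
volume[10:20, 10:20, 2] = 1.0
label_volume = segment_volume(volume, dummy_predict)
```

Each slice is predicted independently, so the output carries no through-plane consistency, which is the limitation noted above.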

Model Card

Intended Use

  • Primary: Automated WMH segmentation for research and clinical decision support
  • Users: Radiologists, neurologists, researchers, AI developers
  • Out-of-scope: Not FDA/CE approved; not for standalone clinical diagnosis

Ethical Considerations

  • Privacy: All data anonymized per HIPAA/GDPR standards
  • Bias: Limited scanner/protocol diversity may affect generalization
  • Clinical Validation: Requires expert review before clinical use
  • Transparency: Complete methodology and code openly available

Model Card Authors

Mahdi Bashiri Bawil, Mousa Shamsi, Abolhassan Shakeri Bavil

Citation

@article{bawil2026neuro,
  title={AI-Driven Multi-Parametric MS Lesion Analysis from T2-FLAIR Imaging: A Clinical Decision Support Framework for Neuroradiology},
  author={Bawil, Mahdi Bashiri and Shamsi, Mousa and Bavil, Abolhassan Shakeri},
  year={2026},
  note={Models: https://huggingface.co/Bawil/neuro-ai}
}

License

MIT License - See LICENSE

Additional Resources

Acknowledgments

  • Golgasht Medical Imaging Center, Tabriz, Iran for providing clinical data
  • Expert neuroradiologists for manual annotations
  • Ethics Committee approval: IR.TBZMED.REC.1402.902

Keywords: Artificial intelligence (AI), multiple sclerosis, neuroradiology, clinical decision support, automated lesion analysis, deep learning, clinical AI
