BraTS2020 Brain Tumor Segmentation: SegFormer (MiT-B3)

This repository contains a SegFormer (MiT-B3) model trained on the BraTS 2020 (Brain Tumor Segmentation Challenge) dataset. The model uses a hierarchical self-attention encoder to segment brain tumor subregions from multimodal MRI. It reaches a mean BraTS Dice score of 82.65% on held-out test volumes.

  • Developer: Bilge
  • Model Architecture: SegFormer-B3 (Hierarchical MixTransformer Encoder + Lightweight MLP Decoder)
  • Framework: PyTorch (stable, full-FP32 precision training)
  • Target Task: Multimodal Brain Tumor Segmentation (FLAIR, T1, T1ce, T2)
  • Classes:
    • NCR/NET (Necrotic & Non-Enhancing Tumor Core)
    • Edema (Peritumoral Edema)
    • ET (Enhancing Active Tumor)
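In the BraTS 2020 ground truth, these classes are stored as integer labels: 1 = NCR/NET, 2 = Edema, 4 = ET (label 3 is unused). The WT/TC/ET scores reported below are computed on nested regions built from those labels. The NumPy sketch below shows that standard mapping. The card does not say whether this model's three output channels are the classes or the regions, so read `labels_to_regions` (a hypothetical helper) as an illustration of the evaluation convention, not of the model's internals.

```python
import numpy as np

# BraTS 2020 label values in the ground-truth segmentation files.
NCR_NET, EDEMA, ET = 1, 2, 4

def labels_to_regions(seg: np.ndarray) -> np.ndarray:
    """Convert a BraTS label map into stacked binary region masks.

    Returns a bool array of shape (3, *seg.shape):
      channel 0: Whole Tumor (WT)     = NCR/NET + Edema + ET
      channel 1: Tumor Core (TC)      = NCR/NET + ET
      channel 2: Enhancing Tumor (ET) = ET
    """
    wt = np.isin(seg, (NCR_NET, EDEMA, ET))
    tc = np.isin(seg, (NCR_NET, ET))
    et = seg == ET
    return np.stack([wt, tc, et])

# Usage on a toy 2x3 label map.
toy = np.array([[0, 1, 2],
                [4, 2, 0]])
regions = labels_to_regions(toy)
print(regions.sum(axis=(1, 2)))  # → [4 2 1]
```

The regions are nested (ET ⊆ TC ⊆ WT). This explains why WT usually scores highest and ET lowest.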

🔗 GitHub Repository: https://github.com/your-username/your-repo-name (Replace with your actual GitHub repo URL)


🏆 Final Performance Metrics

The model was trained for 30 epochs in full FP32 precision on a single GPU, with Exponential Moving Average (EMA) shadow weights (decay = 0.999). It achieves the following per-volume Dice scores on the test set:

| Metric | Dice Score | What It Measures |
|---|---|---|
| BraTS Mean Score | 82.65% | Average of the WT, TC, and ET scores; overall tumor localization. |
| Whole Tumor (WT) | 88.87% | NCR/NET + Edema + ET; includes the peripheral vasogenic edema boundary. |
| Tumor Core (TC) | 83.72% | NCR/NET + ET; the necrotic core and active tumor zones. |
| Enhancing Tumor (ET) | 75.36% | ET only; the active enhancing rim and the lowest-scoring region. |
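The card does not include evaluation code. The following is a minimal NumPy sketch of per-volume Dice under the usual BraTS definitions. It assumes predictions and ground truth have already been converted to WT/TC/ET binary masks. It also scores an empty prediction against an empty ground truth as 1.0, a convention that differs across evaluation tools. `volume_scores` and `dataset_scores` are illustrative names, not functions from this repository.

```python
import numpy as np

REGIONS = ("WT", "TC", "ET")

def dice(pred, target):
    """Dice = 2|P ∩ G| / (|P| + |G|) for two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        # Both masks empty: scored as a perfect match here (assumed convention).
        return 1.0
    return float(2.0 * np.logical_and(pred, target).sum() / denom)

def volume_scores(pred_regions, gt_regions):
    """Per-region Dice for one volume; inputs are (3, ...) WT/TC/ET masks."""
    scores = {name: dice(p, g) for name, p, g in zip(REGIONS, pred_regions, gt_regions)}
    scores["Mean"] = float(np.mean([scores[r] for r in REGIONS]))
    return scores

def dataset_scores(per_volume):
    """Average each entry over volumes (per-volume averaging)."""
    return {key: float(np.mean([s[key] for s in per_volume])) for key in per_volume[0]}

# The mean row of the table is the average of the three region rows.
reported = {"WT": 88.87, "TC": 83.72, "ET": 75.36}
print(round(np.mean(list(reported.values())), 2))  # → 82.65
```

Averaging per volume weights every patient equally. Pooling all voxels first would instead favor patients with large tumors.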

🧠 Architectural Highlights: Hierarchical MixTransformer (MiT-B3)

A standard Vision Transformer (ViT) produces single-scale feature maps, and its global self-attention becomes expensive at high input resolutions. SegFormer instead uses a Hierarchical MixTransformer (MiT-B3) encoder:

Input Image ──► [ Stage 1 (H/4) ] ──► [ Stage 2 (H/8) ] ──► [ Stage 3 (H/16) ] ──► [ Stage 4 (H/32) ]
                        │                     │                     │                      │
                        ▼                     ▼                     ▼                      ▼
                              [ All Stages Concatenated & Fed into MLP Decoder Head ]
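The following NumPy sketch makes the decoder path in the diagram concrete at the level of tensor shapes, using random weights. Each stage is linearly projected to a common width, resized to the Stage 1 (H/4) grid, concatenated, fused, and classified. It rests on several assumptions. The stage widths are 64/128/320/512 (MiT-B3). The decoder width is a hypothetical 64, kept small for the demo. Nearest-neighbour resizing stands in for the bilinear upsampling SegFormer actually uses, and batch normalization is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

def linear(weight, x):
    """Per-pixel linear layer (equivalent to a 1x1 convolution) on a (C, H, W) map."""
    c, h, w = x.shape
    return (weight @ x.reshape(c, h * w)).reshape(-1, h, w)

def nearest_resize(x, out_h, out_w):
    """Nearest-neighbour resize of a (C, H, W) map; SegFormer itself uses bilinear."""
    _, h, w = x.shape
    rows = np.arange(out_h) * h // out_h
    cols = np.arange(out_w) * w // out_w
    return x[:, rows][:, :, cols]

def mlp_decoder(features, embed_dim, num_classes):
    """All-MLP head: project each stage, resize to the Stage 1 grid, concat, fuse, classify.

    Weights are random here; a trained model loads learned ones.
    """
    out_h, out_w = features[0].shape[1:]  # Stage 1 resolution (H/4)
    projected = []
    for f in features:
        w_proj = rng.standard_normal((embed_dim, f.shape[0])) * 0.01
        projected.append(nearest_resize(linear(w_proj, f), out_h, out_w))
    stacked = np.concatenate(projected, axis=0)  # (4 * embed_dim, H/4, W/4)
    w_fuse = rng.standard_normal((embed_dim, stacked.shape[0])) * 0.01
    fused = np.maximum(linear(w_fuse, stacked), 0.0)  # linear fuse + ReLU
    w_cls = rng.standard_normal((num_classes, embed_dim)) * 0.01
    return linear(w_cls, fused)  # logits at H/4; upsampled to H afterwards

# Stage outputs for one 240x240 slice, assuming MiT-B3 widths 64/128/320/512.
sides = [60, 30, 15, 8]
widths = [64, 128, 320, 512]
features = [rng.standard_normal((c, s, s)) for c, s in zip(widths, sides)]
logits = mlp_decoder(features, embed_dim=64, num_classes=3)
print(logits.shape)  # → (3, 60, 60)
```

Because every stage is brought to the same grid before fusion, the head itself contains only per-pixel linear layers. This is why SegFormer's decoder stays lightweight.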

Key Advantages:

  1. Multi-Scale Feature Learning: Produces hierarchical feature maps at four resolutions ($1/4, 1/8, 1/16, 1/32$ of the input), capturing both fine localization cues (high-resolution stages) and global context (low-resolution stages).
  2. Positional-Encoding Free (Mix-FFN): Instead of fixed positional encodings, each Feed-Forward Network contains a 3×3 depth-wise convolution, which leaks positional information through zero padding. Because no positional-embedding table is tied to the training resolution, the model can run inference at resolutions different from training.
  3. Overlapped Patch Merging: The patch-embedding convolutions use a kernel larger than their stride, so neighboring patches overlap and spatial continuity is preserved at patch boundaries. This helps on small structures such as Enhancing Tumor margins.
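The overlap in point 3 comes from strided convolutions whose kernel exceeds the stride. The pure-Python sketch below computes the resulting feature-map sizes for a 240×240 BraTS slice. It assumes the kernel/stride settings reported in the SegFormer paper (7/4 for Stage 1, 3/2 for later stages) with padding of kernel // 2. It does not read this checkpoint, so confirm the values against the repository's model config.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a strided convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride) per stage, as in the SegFormer paper; padding = kernel // 2.
STAGES = [(7, 4), (3, 2), (3, 2), (3, 2)]

def stage_resolutions(size: int):
    """Feature-map side length after each overlapped patch-merging step."""
    out = []
    for kernel, stride in STAGES:
        size = conv_out(size, kernel, stride, kernel // 2)
        out.append(size)
    return out

for (kernel, stride), side in zip(STAGES, stage_resolutions(240)):
    # kernel > stride, so neighbouring patches share (kernel - stride) pixels.
    print(f"k={kernel} s={stride} overlap={kernel - stride} -> {side}x{side}")
# k=7 s=4 overlap=3 -> 60x60
# k=3 s=2 overlap=1 -> 30x30
# k=3 s=2 overlap=1 -> 15x15
# k=3 s=2 overlap=1 -> 8x8
```

The last stage rounds H/32 = 7.5 up to 8. The same arithmetic accepts other input sizes (for example 256 gives 64/32/16/8), which is the kind of resolution change the Mix-FFN design tolerates.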

🚀 How to Load and Predict in PyTorch

To use this model, clone the GitHub repository that contains the source code (the src/ folder, which provides build_model). Then download the best.pt weights from this Hugging Face repository:

import torch
from huggingface_hub import hf_hub_download
from src.models import build_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Rebuild the model structure from local code
model = build_model(
    model_name="segformer_b3",
    in_channels=4,
    out_channels=3,
    encoder_weights=None,
    pretrained_path=None
).to(device)

# 2. Download and load the weights from Hugging Face
ckpt_path = hf_hub_download(repo_id="your-username/brats2020-segformer-mit3", filename="best.pt")
ckpt = torch.load(ckpt_path, map_location=device)
state_dict = ckpt.get("model_state_dict", ckpt)

# Clean DataParallel prefixes if present
cleaned_state = {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}
model.load_state_dict(cleaned_state)

# 3. Enter Eval Mode
model.eval()

# The model is now ready for inference.
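The card reports per-volume scores, but the model consumes 2D slices. Volumes are therefore presumably segmented slice by slice and restacked. Below is a minimal NumPy sketch of that loop with a pluggable `predict_slice` callable (a hypothetical name). The sketch makes two assumptions: an independent sigmoid per output channel with a 0.5 threshold, and a channel-first (4, H, W, D) volume. Neither is confirmed by this card; if the model was trained with softmax, use an argmax instead. The commented wrapper shows how the PyTorch model loaded above might plug in, assuming it returns (1, 3, H, W) logits at input resolution.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def predict_volume(volume, predict_slice, threshold=0.5):
    """Segment a (4, H, W, D) volume one axial slice at a time.

    predict_slice: callable taking a (4, H, W) array and returning (3, H, W) logits.
    Returns a boolean (3, H, W, D) mask, one channel per model output.
    """
    masks = []
    for z in range(volume.shape[-1]):
        logits = predict_slice(volume[..., z])
        # Assumption: independent sigmoid per channel (multi-label output).
        masks.append(sigmoid(logits) > threshold)
    return np.stack(masks, axis=-1)

# Hypothetical wrapper around the PyTorch model loaded above:
#
# def predict_slice(x):
#     with torch.no_grad():
#         t = torch.from_numpy(x).float().unsqueeze(0).to(device)
#         return model(t)[0].cpu().numpy()
#
# mask = predict_volume(volume, predict_slice)  # volume: (4, 240, 240, 155)

# Self-contained demo with a stand-in predictor that thresholds FLAIR intensity.
demo = np.zeros((4, 8, 8, 5), dtype=np.float32)
demo[0, :4] = 1.0
demo[0, 4:] = -1.0
mask = predict_volume(demo, lambda x: np.repeat(x[:1], 3, axis=0))
print(mask.shape, mask[:, :4].all(), mask[:, 4:].any())  # → (3, 8, 8, 5) True False
```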

📂 Dataset & Preprocessing Information

The model expects normalized multi-modal 2D slices of shape (4, 240, 240) containing:

  1. FLAIR (Fluid-Attenuated Inversion Recovery)
  2. T1 (T1-weighted)
  3. T1ce (T1-weighted Contrast-Enhanced)
  4. T2 (T2-weighted)

Before being fed to the model, each modality is z-score normalized channel-wise using only its foreground voxels (intensity $> 0$). This puts the four modalities on a comparable intensity scale and keeps training numerically stable.
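A minimal NumPy sketch of this preprocessing follows. It makes three assumptions: the foreground mask is computed per modality as voxels > 0, background voxels are left at 0 after normalization, and the four scans have already been loaded as arrays (for example from the BraTS NIfTI files). `zscore_foreground` and `prepare_input` are hypothetical helpers, not functions from the repository's src/ package. The same code works on a full (240, 240, 155) volume or a single 240×240 slice.

```python
import numpy as np

MODALITIES = ("flair", "t1", "t1ce", "t2")  # channel order listed above

def zscore_foreground(channel, eps=1e-8):
    """Z-score one modality using only voxels > 0; background stays at 0."""
    out = np.zeros_like(channel, dtype=np.float32)
    mask = channel > 0
    if mask.any():
        fg = channel[mask].astype(np.float64)
        out[mask] = (fg - fg.mean()) / (fg.std() + eps)
    return out

def prepare_input(scans):
    """Stack per-modality arrays (dict name -> array) into a normalized (4, ...) array."""
    return np.stack([zscore_foreground(scans[m]) for m in MODALITIES])

# Usage with random stand-in slices (real data would come from the NIfTI files).
rng = np.random.default_rng(0)
scans = {m: rng.integers(0, 500, size=(240, 240)).astype(np.float32) for m in MODALITIES}
x = prepare_input(scans)
print(x.shape, x.dtype)  # → (4, 240, 240) float32
```

Computing the statistics on foreground voxels only keeps the large zero-valued background from dragging the mean toward zero and shrinking the standard deviation.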
