BraTS2020 Brain Tumor Segmentation: SegFormer (MiT-B3)
This repository contains a SegFormer (MiT-B3) model trained on the BraTS 2020 (Brain Tumor Segmentation Challenge) dataset. Built on a hierarchical self-attention encoder, it reaches a mean per-volume Dice score of 82.65% over the WT, TC, and ET tumor regions.
- Developer: Bilge
- Model Architecture: SegFormer-B3 (Hierarchical MixTransformer Encoder + Lightweight MLP Decoder)
- Framework: PyTorch (full FP32 precision training)
- Target Task: Multimodal Brain Tumor Segmentation (FLAIR, T1, T1ce, T2)
- Classes:
  - NCR/NET (Necrotic & Non-Enhancing Tumor Core)
  - Edema (Peritumoral Edema)
  - ET (Enhancing Active Tumor)
🔗 GitHub Repository: https://github.com/your-username/your-repo-name (Replace with your actual GitHub repo URL)
🏆 Final Performance Metrics
Trained for 30 epochs in full FP32 precision on a single GPU, with Exponential Moving Average (EMA) shadow weights (decay = 0.999), the SegFormer model achieves the following per-volume Dice scores on the test set:
| Metric | Dice Score | Region Definition |
|---|---|---|
| BraTS Mean Score | 82.65% | Unweighted mean of the WT, TC, and ET scores. |
| Whole Tumor (WT) | 88.87% | All tumor sub-regions: NCR/NET, peritumoral edema, and ET. |
| Tumor Core (TC) | 83.72% | NCR/NET plus ET (edema excluded). |
| Enhancing Tumor (ET) | 75.36% | Enhancing tumor only. |
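The EMA shadow weights mentioned above follow the usual update rule $\theta_{\text{EMA}} \leftarrow d\,\theta_{\text{EMA}} + (1-d)\,\theta$ with $d = 0.999$. The sketch below illustrates that rule on hypothetical scalar weights stored in a dict; it is not the repository's own EMA implementation, which applies the same idea tensor-by-tensor.

```python
def ema_update(shadow, params, decay=0.999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * params.

    `shadow` and `params` are hypothetical dicts mapping parameter names to
    floats; a PyTorch version applies the same rule to every weight tensor.
    """
    return {name: decay * shadow[name] + (1.0 - decay) * params[name] for name in shadow}


# Toy run: the live weight jumps to 1.0 and stays there, while the shadow
# weight approaches it gradually at a rate set by the decay.
shadow = {"w": 0.0}
live = {"w": 1.0}
for _ in range(1000):
    shadow = ema_update(shadow, live)
print(round(shadow["w"], 3))  # 1 - 0.999**1000, prints 0.632
```

With $d = 0.999$ the average effectively spans roughly $1/(1-d) = 1000$ optimizer steps, which smooths out step-to-step noise in the weights. The card does not state whether `best.pt` stores the EMA weights or the raw training weights.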
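The region scores above follow the standard BraTS definitions, where WT, TC, and ET are unions of label values. A minimal sketch of per-volume region Dice, assuming the BraTS 2020 label convention (0 = background, 1 = NCR/NET, 2 = edema, 4 = ET); it is not the repository's evaluation code, and scoring two empty masks as 1.0 is an assumption that differs between evaluation pipelines.

```python
import numpy as np

# BraTS 2020 label values that make up each evaluated region
REGIONS = {
    "WT": (1, 2, 4),  # whole tumor: NCR/NET + edema + ET
    "TC": (1, 4),     # tumor core: NCR/NET + ET
    "ET": (4,),       # enhancing tumor
}


def dice(pred_mask, true_mask):
    """Dice = 2|P ∩ T| / (|P| + |T|) for boolean masks of any shape."""
    inter = np.logical_and(pred_mask, true_mask).sum()
    denom = pred_mask.sum() + true_mask.sum()
    if denom == 0:
        return 1.0  # both masks empty (assumed convention)
    return float(2.0 * inter / denom)


def brats_region_dice(pred_labels, true_labels):
    """Per-region Dice for one volume, plus the unweighted BraTS mean."""
    scores = {
        name: dice(np.isin(pred_labels, labels), np.isin(true_labels, labels))
        for name, labels in REGIONS.items()
    }
    scores["Mean"] = float(np.mean([scores[name] for name in REGIONS]))
    return scores


true = np.array([[0, 1, 2], [4, 4, 0]])
pred = np.array([[0, 1, 2], [4, 0, 0]])
print(brats_region_dice(pred, true))
```

Per-volume scores like these are then averaged over all test volumes. As a sanity check on the table, (88.87 + 83.72 + 75.36) / 3 = 82.65, so the reported mean is the unweighted average of the three region scores.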
🧠 Architectural Highlights: Hierarchical MixTransformer (MiT-B3)
Unlike standard Vision Transformers (ViT), which produce single-scale feature maps and whose global self-attention cost grows quadratically with the number of image tokens, SegFormer uses a Hierarchical MixTransformer (MiT-B3) encoder:
```
Input Image ──► [ Stage 1 (H/4) ] ──► [ Stage 2 (H/8) ] ──► [ Stage 3 (H/16) ] ──► [ Stage 4 (H/32) ]
                        │                     │                     │                      │
                        └─────────────────────┴──────────┬──────────┴──────────────────────┘
                                                         ▼
                              [ All Stages Concatenated & Fed into MLP Decoder Head ]
```
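For the 240 × 240 BraTS slices, the stage resolutions in the diagram can be checked with the standard convolution output-size formula. This sketch assumes the MiT overlapped patch embeddings described in the SegFormer paper (kernel 7, stride 4, padding 3 for stage 1; kernel 3, stride 2, padding 1 for stages 2 to 4), not values read from this checkpoint. Because each kernel is larger than its stride, neighboring patches overlap.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of each stage's overlapped patch embedding (assumed)
PATCH_EMBED = [(7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1)]


def stage_resolutions(size):
    """Spatial size of the feature map produced by each of the four stages."""
    sizes = []
    for kernel, stride, padding in PATCH_EMBED:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(stage_resolutions(240))  # [60, 30, 15, 8]
```

Since 240 is not divisible by 32, stage 4 yields an 8 × 8 map rather than exactly H/32 = 7.5.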
Key Advantages:
- Multi-Scale Feature Learning: Generates hierarchical feature maps at different resolutions ($1/4, 1/8, 1/16, 1/32$ of the input), capturing both fine-grained localization cues and global context.
- Positional-Encoding Free (Mix-FFN): Places a 3×3 depth-wise convolution inside each feed-forward network; together with zero padding, this leaks enough positional information that no explicit positional encoding is needed, so the model can run inference at resolutions different from those used in training.
- Overlapped Patch Merging: Downsamples with overlapping patches (kernel larger than stride), preserving local continuity across patch boundaries; this is intended to help on thin structures such as the margins of the Enhancing Tumor.
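A minimal NumPy sketch of the Mix-FFN idea, following the formula $x_{out} = \mathrm{MLP}(\mathrm{GELU}(\mathrm{Conv}_{3\times3}(\mathrm{MLP}(x_{in})))) + x_{in}$ from the SegFormer paper. Biases are omitted and the weights passed in are hypothetical placeholders, so this illustrates the mechanism rather than reproducing this repository's layers. The last lines show why zero padding leaks position: a constant input produces different responses at corners, edges, and the interior.

```python
import numpy as np


def depthwise_conv3x3(x, kernels):
    """Depth-wise 3x3 convolution (cross-correlation, as in PyTorch), stride 1, zero padding 1.

    x: (C, H, W) feature map; kernels: (C, 3, 3), one filter per channel.
    """
    _, h, w = x.shape
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)))  # constant zero padding by default
    out = np.zeros(x.shape, dtype=np.float64)
    for i in range(3):
        for j in range(3):
            out += kernels[:, i, j][:, None, None] * padded[:, i:i + h, j:j + w]
    return out


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def mix_ffn(x, w1, dw_kernels, w2):
    """x: (C, H, W); w1: (hidden, C); dw_kernels: (hidden, 3, 3); w2: (C, hidden)."""
    hidden = np.einsum("oc,chw->ohw", w1, x)          # first MLP (per-pixel projection)
    hidden = gelu(depthwise_conv3x3(hidden, dw_kernels))
    return x + np.einsum("co,ohw->chw", w2, hidden)   # second MLP plus residual


# A constant input carries no positional signal, yet the zero-padded
# depth-wise convolution makes the response depend on pixel location.
ones = np.ones((1, 5, 5))
response = depthwise_conv3x3(ones, np.ones((1, 3, 3)))
print(response[0, 0, 0], response[0, 0, 2], response[0, 2, 2])  # 4.0 6.0 9.0
```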
🚀 How to Load and Predict in PyTorch
To use this model, clone the GitHub repository containing the source code (the `src/` folder) and download the `best.pt` weights from this Hugging Face repository:
```python
import torch
from huggingface_hub import hf_hub_download

from src.models import build_model  # provided by the GitHub repository's src/ folder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Rebuild the model architecture from the local source code
model = build_model(
    model_name="segformer_b3",
    in_channels=4,  # four MRI modalities
    out_channels=3,
    encoder_weights=None,
    pretrained_path=None,
).to(device)

# 2. Download the checkpoint from Hugging Face and load the weights
ckpt_path = hf_hub_download(repo_id="your-username/brats2020-segformer-mit3", filename="best.pt")
ckpt = torch.load(ckpt_path, map_location=device)
state_dict = ckpt.get("model_state_dict", ckpt)

# Strip the "module." prefix that DataParallel adds to parameter names, if present
cleaned_state = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(cleaned_state)

# 3. Switch to evaluation mode (e.g. disables dropout)
model.eval()
# The model is now ready for inference.
```
📂 Dataset & Preprocessing Information
The model expects normalized multi-modal 2D slices of shape (4, 240, 240) containing:
- FLAIR (Fluid-Attenuated Inversion Recovery)
- T1 (T1-weighted)
- T1ce (T1-weighted Contrast-Enhanced)
- T2 (T2-weighted)
Before being passed to the model, each channel is z-score normalized using only its foreground voxels (intensity $> 0$), i.e. $x' = (x - \mu_{\text{fg}}) / \sigma_{\text{fg}}$, to improve numerical stability.
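A minimal sketch of building a model input under stated assumptions: the channel order follows the modality list above, statistics are computed per modality over the whole 3D volume (the text does not say whether per volume or per slice), background voxels are left at 0, and axial slices lie along the last axis. `MODALITIES`, `zscore_foreground`, `normalize_case`, and `get_slice` are hypothetical names, not the repository's API.

```python
import numpy as np

# Channel order assumed to follow the modality list above.
MODALITIES = ("flair", "t1", "t1ce", "t2")


def zscore_foreground(volume, eps=1e-8):
    """Z-score one modality using only its foreground voxels (intensity > 0).

    Background voxels stay at 0; eps guards against a zero standard deviation.
    """
    out = np.zeros(volume.shape, dtype=np.float32)
    mask = volume > 0
    if mask.any():
        fg = volume[mask].astype(np.float32)
        out[mask] = (fg - fg.mean()) / (fg.std() + eps)
    return out


def normalize_case(volumes):
    """volumes: dict of modality name -> (H, W, D) array. Returns (4, H, W, D) float32."""
    return np.stack([zscore_foreground(volumes[m]) for m in MODALITIES], axis=0)


def get_slice(case, z):
    """Extract the (4, H, W) model input for slice z (slice axis assumed last)."""
    return case[..., z]
```

For BraTS 2020 volumes of 240 × 240 × 155 voxels, `get_slice(normalize_case(volumes), z)` returns an array of shape (4, 240, 240). Normalizing the full volume once and then slicing avoids recomputing statistics for every slice.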