PicoUNet for BraTS 2020 Tumor Segmentation
This is a deep learning model for brain tumor segmentation, trained on the BraTS 2020 dataset. It uses PicoUNet, a lightweight variant of U-Net, to segment brain tumors slice by slice (2D axial slices) from 3D MRI scans.
Model Description
- Architecture: PicoUNet (Encoder-Decoder with Skip Connections)
- Input Modalities: 2 channels (FLAIR, T1Ce). Note: standard BraTS provides 4 modalities (T1, T1Ce, T2, FLAIR); this model uses only a subset.
- Output Classes: 4 classes (remapped from BraTS labels)
  - Class 0: Background
  - Class 1: Necrotic and Non-Enhancing Tumor Core (NCR/NET), BraTS label 1
  - Class 2: Peritumoral Edema (ED), BraTS label 2
  - Class 3: Enhancing Tumor (ET), BraTS label 4 (remapped to 3 because BraTS does not use label 3)
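The label remapping above can be sketched in NumPy as follows. This is a minimal illustration, not the project's training code; the helper names `brats_to_classes` and `classes_to_brats` are hypothetical. The inverse mapping is only needed if predictions are to be written back in the BraTS label convention.

```python
import numpy as np

# BraTS 2020 ground-truth labels: 0 = background, 1 = NCR/NET, 2 = ED, 4 = ET.
# Label 3 is unused, so 4 is remapped to 3 to obtain contiguous class indices 0..3.
BRATS_TO_CLASS = {0: 0, 1: 1, 2: 2, 4: 3}
CLASS_TO_BRATS = {v: k for k, v in BRATS_TO_CLASS.items()}


def brats_to_classes(mask: np.ndarray) -> np.ndarray:
    """Map a BraTS label mask to model class indices 0..3 (unknown labels become 0)."""
    out = np.zeros_like(mask, dtype=np.int64)
    for brats_label, cls in BRATS_TO_CLASS.items():
        out[mask == brats_label] = cls
    return out


def classes_to_brats(pred: np.ndarray) -> np.ndarray:
    """Map predicted class indices 0..3 back to BraTS labels {0, 1, 2, 4}."""
    out = np.zeros_like(pred, dtype=np.int64)
    for cls, brats_label in CLASS_TO_BRATS.items():
        out[pred == cls] = brats_label
    return out


mask = np.array([[0, 1], [2, 4]])
print(brats_to_classes(mask).tolist())  # [[0, 1], [2, 3]]
```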
Dataset
The model was trained on the MICCAI BraTS 2020 Challenge data. The dataset consists of multimodal MRI scans of patients with glioblastoma (GBM/HGG) and lower-grade glioma (LGG). All volumes are co-registered to the same anatomical template (SRI24) and resampled to an isotropic resolution of $1\,\mathrm{mm}^3$.
Preprocessing
- Slicing: 2D axial slices were extracted.
- Resizing: Images were resized to 128x128.
- Normalization: Min-Max normalization per slice.
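The preprocessing steps above can be sketched for a single axial slice as shown below. This is a hedged sketch under several assumptions not stated in this card: volumes are NumPy arrays shaped (H, W, D) with the axial axis last (as BraTS NIfTI files typically load with nibabel), resizing uses nearest-neighbour interpolation, and min-max scaling is applied per channel within each slice. The function names are hypothetical.

```python
import numpy as np


def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2D array to (size, size).
    Assumption: the interpolation method used in training is not documented."""
    h, w = img.shape
    rows = (np.arange(size) * h / size).astype(np.int64)
    cols = (np.arange(size) * w / size).astype(np.int64)
    return img[rows[:, None], cols[None, :]]


def minmax(img: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale a 2D slice to [0, 1]; a constant slice (e.g. empty background) becomes all zeros."""
    lo, hi = float(img.min()), float(img.max())
    return ((img - lo) / (hi - lo + eps)).astype(np.float32)


def preprocess_slice(flair: np.ndarray, t1ce: np.ndarray, z: int, size: int = 128) -> np.ndarray:
    """Build a (2, size, size) model input from axial slice z of two (H, W, D) volumes."""
    channels = [minmax(resize_nearest(vol[:, :, z], size)) for vol in (flair, t1ce)]
    return np.stack(channels, axis=0)  # channel order: [FLAIR, T1Ce]


rng = np.random.default_rng(0)
flair = rng.random((240, 240, 155))  # random stand-ins for real 240x240x155 BraTS volumes
t1ce = rng.random((240, 240, 155))
x = preprocess_slice(flair, t1ce, z=77)
print(x.shape, x.dtype)  # (2, 128, 128) float32
```

The epsilon in `minmax` guards against division by zero on slices that contain no brain tissue, which are common at the top and bottom of each volume.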
Training Configuration
Auto-generated from the model checkpoint:
- Epochs Trained: 99
- Validation Loss: 0.008356
- Experiment Name: Brats-UNet
- Batch Size: 32
- Learning Rate: 0.03
- Seed: 42
- Optimizer: SGD (Momentum 0.9)
- Loss Function: CrossEntropyLoss (or DiceLoss depending on experiment)
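Which loss a given checkpoint used is not recorded beyond the line above. For reference, a common multi-class soft Dice loss can be sketched in NumPy as follows; this is an illustrative formulation (mean over all 4 classes, background included), not necessarily the `DiceLoss` used in these experiments. The function name and `eps` smoothing term are assumptions.

```python
import numpy as np


def soft_dice_loss(probs: np.ndarray, target: np.ndarray, num_classes: int = 4, eps: float = 1e-6) -> float:
    """Mean soft Dice loss over classes.

    probs:  (N, C, H, W) softmax probabilities.
    target: (N, H, W) integer class indices in [0, C).
    """
    one_hot = np.moveaxis(np.eye(num_classes)[target], -1, 1)  # (N, C, H, W)
    axes = (0, 2, 3)                                            # sum over batch and pixels, keep classes
    intersection = (probs * one_hot).sum(axis=axes)
    denom = probs.sum(axis=axes) + one_hot.sum(axis=axes)
    dice = (2 * intersection + eps) / (denom + eps)             # per-class Dice, shape (C,)
    return float(1.0 - dice.mean())


target = np.array([[[0, 1], [2, 3]]])                  # (1, 2, 2) ground-truth classes
perfect = np.moveaxis(np.eye(4)[target], -1, 1)        # one-hot used as "perfect" probabilities
print(soft_dice_loss(perfect, target))  # 0.0
```

Unlike cross-entropy, this loss is computed on region overlap, which makes it less sensitive to the heavy class imbalance between background and the small tumor sub-regions.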
Usage
Inputs must be preprocessed the same way as in training: take a 2D axial slice, stack the FLAIR and T1Ce channels (in that order), resize to 128x128, and apply per-slice min-max normalization.
```python
import torch
from miccai_brats.models.unet.unet import PicoUNet

# 1. Instantiate the model (2 input channels: FLAIR, T1Ce; 4 output classes)
model = PicoUNet(in_channels=2, num_classes=4)

# 2. Load the checkpoint (the weights are stored under 'model_state_dict')
checkpoint = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# 3. Inference on a dummy batch shaped [batch, channels, H, W] = [1, 2, 128, 128]
input_tensor = torch.randn(1, 2, 128, 128)  # random stand-in for a preprocessed [FLAIR, T1Ce] slice
with torch.no_grad():
    output = model(input_tensor)               # per-class logits, [1, 4, 128, 128]
    prediction = torch.argmax(output, dim=1)   # class indices 0..3

print("Predicted shape:", prediction.shape)  # torch.Size([1, 128, 128])
```
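Because the model works on single 2D slices, a full 3D segmentation requires predicting every axial slice and stacking the results. The sketch below assumes preprocessed slices shaped (D, 2, 128, 128) and an original in-plane size of 240x240 (the BraTS volume size); `predict_fn` is a hypothetical callable that maps one (2, 128, 128) array to (128, 128) class indices. It could wrap the PyTorch model loaded in the snippet above: call it under `torch.no_grad()`, add a batch dimension, and take the argmax over dim 1.

```python
import numpy as np


def resize_nearest(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of a 2D array (required for label maps: no blended labels)."""
    h, w = img.shape
    rows = (np.arange(out_h) * h / out_h).astype(np.int64)
    cols = (np.arange(out_w) * w / out_w).astype(np.int64)
    return img[rows[:, None], cols[None, :]]


def predict_volume(predict_fn, inputs: np.ndarray, out_hw=(240, 240)) -> np.ndarray:
    """Run a 2D predictor on every axial slice and stack the results into an (H, W, D) volume."""
    slices = []
    for x in inputs:                           # x: (2, 128, 128), channels [FLAIR, T1Ce]
        pred = predict_fn(x)                   # class indices 0..3, shape (128, 128)
        pred = resize_nearest(pred, *out_hw)   # back to the original in-plane size
        pred = np.where(pred == 3, 4, pred)    # class 3 -> BraTS label 4 (ET)
        slices.append(pred)
    return np.stack(slices, axis=-1)           # axial axis last, as in the input volumes


# Dummy predictor that labels every pixel as class 3, used only to exercise the pipeline.
dummy_fn = lambda x: np.full(x.shape[1:], 3)
vol = predict_volume(dummy_fn, np.zeros((4, 2, 128, 128), dtype=np.float32))
print(vol.shape)  # (240, 240, 4)
```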
Intended Use & Limitations
- Research Use Only: This model is for educational and research purposes. It is not intended for clinical use.
- Performance: This is a "Pico" variant, optimized for speed and low compute, not for state-of-the-art accuracy.
Citations
If you use the BraTS dataset, please cite:
- Menze et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE TMI 2015.
- Bakas et al. "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Scientific Data 2017.
- Bakas et al. "Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge", arXiv 2018.