---
language:
- en
tags:
- medical-segmentation
- pytorch
- brats
- unet
- mri
- brain-tumor
license: mit
metrics:
- val_loss: 0.008356
model-index:
- name: PicoUNet
  results: []
---
# PicoUNet for BraTS 2020 Tumor Segmentation

A deep learning model for brain tumor segmentation, trained on 2D axial slices of the 3D volumes in the BraTS 2020 dataset. It uses a PicoUNet architecture (a lightweight U-Net variant) to segment brain tumors from MRI scans.
## Model Description

- Architecture: PicoUNet (encoder-decoder with skip connections)
- Input Modalities: 2 channels (FLAIR, T1ce). Note: standard BraTS provides 4 modalities (T1, T1ce, T2, FLAIR); this model uses a subset.
- Output Classes: 4 (mapped from BraTS labels)
  - Class 0: Background
  - Class 1: Necrotic and Non-Enhancing Tumor Core (NCR/NET), BraTS label 1
  - Class 2: Peritumoral Edema (ED), BraTS label 2
  - Class 3: Enhancing Tumor (ET), BraTS label 4 (remapped to 3)
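The label remapping above can be sketched as follows (`remap_labels` is an illustrative helper, not part of the released code; it assumes a NumPy label array as loaded from a BraTS segmentation file):

```python
import numpy as np

def remap_labels(seg: np.ndarray) -> np.ndarray:
    """Map BraTS labels {0, 1, 2, 4} to contiguous model classes {0, 1, 2, 3}."""
    out = seg.copy()
    out[seg == 4] = 3  # Enhancing Tumor: BraTS label 4 -> class 3
    return out

seg = np.array([0, 1, 2, 4])
print(remap_labels(seg))  # [0 1 2 3]
```

Making the classes contiguous lets the segmentation head use a plain 4-way softmax without an unused channel for label 3.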
## Dataset

The model was trained on the MICCAI BraTS 2020 Challenge data, which consists of multimodal MRI scans of glioblastoma (GBM/HGG) and lower-grade glioma (LGG). All volumes are co-registered to the same template space (SRI24) and interpolated to the same isotropic resolution ($1\,\text{mm}^3$).
### Preprocessing

- Slicing: 2D axial slices were extracted from the 3D volumes.
- Resizing: slices were resized to 128×128.
- Normalization: min-max normalization per slice.
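The preprocessing steps above can be sketched in PyTorch (a minimal illustration, not the training pipeline itself; `preprocess_slice` and the per-channel normalization details are assumptions):

```python
import torch
import torch.nn.functional as F

def preprocess_slice(flair: torch.Tensor, t1ce: torch.Tensor) -> torch.Tensor:
    """Stack two modalities, resize to 128x128, min-max normalize per channel."""
    x = torch.stack([flair, t1ce]).unsqueeze(0).float()  # [1, 2, H, W]
    x = F.interpolate(x, size=(128, 128), mode="bilinear", align_corners=False)
    # Min-max normalization to [0, 1], computed per channel of the slice
    mins = x.amin(dim=(2, 3), keepdim=True)
    maxs = x.amax(dim=(2, 3), keepdim=True)
    return (x - mins) / (maxs - mins + 1e-8)

x = preprocess_slice(torch.rand(240, 240), torch.rand(240, 240))
print(x.shape)  # torch.Size([1, 2, 128, 128])
```

BraTS volumes are 240×240×155, so each axial slice starts at 240×240 before resizing.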
## Training Configuration

Auto-generated from the model checkpoint:

- Epochs Trained: 99
- Validation Loss: 0.008356
- Experiment Name: Brats-UNet
- Batch Size: 32
- Learning Rate: 0.03
- Seed: 42
- Optimizer: SGD (momentum 0.9)
- Loss Function: CrossEntropyLoss (or DiceLoss, depending on the experiment)
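The configuration above translates to a setup like the following (a minimal sketch only; since the training script is not part of this card, `PicoUNet` is swapped for a stand-in 1×1 convolution with the same input/output channels):

```python
import torch
from torch import nn

torch.manual_seed(42)  # seed from the configuration above

# Hypothetical stand-in for PicoUNet: any 2-in / 4-out segmentation net fits here
model = nn.Conv2d(in_channels=2, out_channels=4, kernel_size=1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
criterion = nn.CrossEntropyLoss()  # DiceLoss in some experiments

# One training step on a dummy batch (batch size 32, 2 modalities, 128x128)
x = torch.randn(32, 2, 128, 128)
y = torch.randint(0, 4, (32, 128, 128))  # per-pixel class targets
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```

`CrossEntropyLoss` expects raw logits of shape `[N, C, H, W]` and integer targets of shape `[N, H, W]`, which is why the label remapping to contiguous classes 0–3 matters.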
## Usage

To use this model, preprocess your input image the same way as the training pipeline: extract the two channels (FLAIR, T1ce) and resize to 128×128.

```python
import torch

from miccai_brats.models.unet.unet import PicoUNet

# 1. Instantiate the model
model = PicoUNet(in_channels=2, num_classes=4)

# 2. Load the checkpoint
checkpoint = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# 3. Inference
input_tensor = torch.randn(1, 2, 128, 128)  # dummy batch of [FLAIR, T1ce]
with torch.no_grad():
    output = model(input_tensor)            # [1, 4, 128, 128] class logits
prediction = torch.argmax(output, dim=1)    # [1, 128, 128] class indices
print("Predicted shape:", prediction.shape)
```
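To compare predictions against BraTS ground truth or submit to the challenge format, model class 3 must be mapped back to BraTS label 4. A minimal sketch (`to_brats_labels` is an illustrative helper, not part of the released code):

```python
import torch

def to_brats_labels(pred: torch.Tensor) -> torch.Tensor:
    """Map model classes {0, 1, 2, 3} back to BraTS labels {0, 1, 2, 4}."""
    out = pred.clone()
    out[pred == 3] = 4  # class 3 -> Enhancing Tumor, BraTS label 4
    return out

pred = torch.tensor([[0, 1, 2, 3]])
print(to_brats_labels(pred))  # tensor([[0, 1, 2, 4]])
```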
## Intended Use & Limitations

- Research Use Only: this model is intended for educational and research purposes. It is not intended for clinical use.
- Performance: this is a "Pico" variant, optimized for speed and low compute, not for state-of-the-art accuracy.
## Citations

If you use the BraTS dataset, please cite:

- Menze et al., "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE TMI, 2015.
- Bakas et al., "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 2017.
- Bakas et al., "Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge", arXiv, 2018.