SegFormer3D

SegFormer3D is a novel and efficient transformer-based architecture designed specifically for 3D medical image segmentation. It extends the successful 2D SegFormer architecture to handle volumetric medical data while maintaining computational efficiency and strong segmentation performance.

Model Description

SegFormer3D introduces several key innovations for efficient 3D medical image segmentation:

  • Hierarchical 3D Feature Learning: Uses a multi-scale transformer encoder with progressively reduced sequence lengths for efficient processing of volumetric data
  • Efficient Self-Attention: Implements spatially-reduced attention mechanism adapted for 3D, reducing computational complexity while maintaining performance
  • All-MLP 3D Decoder: Lightweight decoder that effectively fuses multi-scale features through simple MLP layers
  • Memory-Efficient Design: Optimized architecture that can process full 3D volumes without patch-based inference

The model achieves state-of-the-art performance on multiple 3D medical segmentation benchmarks while being computationally efficient.
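The core of the efficiency claim is the spatially-reduced attention: keys and values are computed on a downsampled copy of the 3D token grid, so attention cost drops from O(N²) to roughly O(N²/r³) for reduction ratio r. The module below is an illustrative sketch of this idea (a hypothetical `SRAttention3D`, not the official implementation):

```python
import torch
import torch.nn as nn

class SRAttention3D(nn.Module):
    """Sketch of spatially-reduced self-attention over 3D tokens.

    Keys/values come from a strided-conv-downsampled token grid, so the
    attention matrix is (N x N/r^3) instead of (N x N).
    Hypothetical module for illustration only.
    """
    def __init__(self, dim, num_heads, sr_ratio):
        super().__init__()
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        if sr_ratio > 1:
            # Strided 3D conv shrinks the K/V grid by sr_ratio along each axis
            self.sr = nn.Conv3d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, d, h, w):
        # x: (batch, n_tokens, dim) with n_tokens == d * h * w
        b, n, c = x.shape
        hd = c // self.num_heads
        q = self.q(x).reshape(b, n, self.num_heads, hd).transpose(1, 2)
        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2).reshape(b, c, d, h, w)
            x_ = self.sr(x_).flatten(2).transpose(1, 2)  # shorter K/V sequence
            x_ = self.norm(x_)
        else:
            x_ = x
        kv = self.kv(x_).reshape(b, -1, 2, self.num_heads, hd)
        k, v = kv.permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * hd ** -0.5   # (b, heads, n, n/r^3)
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)
```

With `sr_ratio=4` and an 8x8x8 token grid, the 512 queries attend to only 8 key/value tokens, which is why the early high-resolution stages in the config use the largest ratios.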

SegFormer3D Architecture

Training and Evaluation

The model was trained and evaluated on several medical imaging datasets:

ACDC Dataset

  • Task: Cardiac MRI segmentation
  • Classes: Left ventricle, right ventricle, and myocardium
  • Performance:
    • Dice Score: 90.96%

ACDC Results

BraTS 2017

  • Task: Brain tumor segmentation
  • Classes: Enhancing tumor, tumor core, and whole tumor
  • Performance:
    • Average Dice: 82.1%

BraTS Results

Synapse Dataset

  • Task: Multi-organ segmentation
  • Classes: 8 abdominal organs
  • Performance:
    • Average Dice: 82.15%

Synapse Results
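All three benchmarks above report the Dice score, which measures volumetric overlap between the predicted and reference masks. A minimal per-class implementation for integer label volumes (an illustrative helper, not the evaluation code used for these numbers) might look like:

```python
import torch

def dice_score(pred, target, num_classes, eps=1e-6):
    """Mean Dice over classes for integer label volumes of shape (D, H, W).

    Dice(c) = 2 * |P_c & T_c| / (|P_c| + |T_c|); eps keeps absent
    classes from producing a 0/0 division.
    """
    scores = []
    for c in range(num_classes):
        p = (pred == c)
        t = (target == c)
        inter = (p & t).sum().float()
        denom = p.sum().float() + t.sum().float()
        scores.append((2 * inter + eps) / (denom + eps))
    return torch.stack(scores).mean()
```

Note that published benchmark numbers often average Dice per case and per foreground class, so exact reproduction depends on the dataset's evaluation protocol.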

Usage

from transformers import SegFormer3DConfig, SegFormer3DModel
import torch

# Initialize configuration
config = SegFormer3DConfig(
    in_channels=4,  # Number of input channels
    num_classes=3,  # Number of segmentation classes
    # Model architecture parameters
    embed_dims=[32, 64, 160, 256],
    num_heads=[1, 2, 5, 8],
    depths=[2, 2, 2, 2],
    sr_ratios=[4, 2, 1, 1]
)

# Initialize model
model = SegFormer3DModel(config)

# Example forward pass
batch_size = 1
depth, height, width = 128, 128, 128  # Example input dimensions
x = torch.randn(batch_size, config.in_channels, depth, height, width)
outputs = model(x)

# Get segmentation logits
logits = outputs.logits  # Shape: (batch_size, num_classes, D, H, W)
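To turn the raw logits from the snippet above into a discrete segmentation mask, take the per-voxel argmax over the class dimension (the `logits` tensor here is a random stand-in for the model output):

```python
import torch

# Stand-in for model output of shape (batch, num_classes, D, H, W)
logits = torch.randn(1, 3, 128, 128, 128)

# Per-voxel class prediction: integer labels in [0, num_classes)
mask = logits.argmax(dim=1)           # shape: (1, 128, 128, 128)

# Per-class probabilities, if soft outputs are needed
probs = torch.softmax(logits, dim=1)  # sums to 1 over the class dim
```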

Limitations

  • Input spatial dimensions must be chosen so that every encoder downsampling stage still produces valid (non-zero) feature-map sizes
  • Memory requirements grow with input volume size
  • Performance may vary across medical imaging modalities
  • Assumes consistent voxel spacing in input volumes

Training Tips

  1. Input Preprocessing:

    • Normalize input volumes to [0, 1] range
    • Consider standardization per modality
    • Use appropriate data augmentation for medical volumes
  2. Training Strategy:

    • Start with a smaller learning rate (1e-4 recommended)
    • Use gradient clipping to stabilize training
    • Consider mixed precision training for memory efficiency
  3. Memory Management:

    • Adjust batch size based on available GPU memory
    • Use gradient checkpointing if needed
    • Consider input volume size vs. model depth trade-offs
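The preprocessing and training-strategy tips above can be sketched as follows. `train_step` is a hypothetical helper combining the low learning rate, gradient clipping, and mixed precision suggestions; the model, optimizer, and loss function are assumed to be supplied by your training script:

```python
import torch

def normalize_volume(vol):
    """Tip 1: min-max normalize a volume to the [0, 1] range."""
    vol = vol.float()
    return (vol - vol.min()) / (vol.max() - vol.min() + 1e-8)

def train_step(model, optimizer, scaler, x, y, loss_fn):
    """Tips 2-3: one mixed-precision step with gradient clipping.

    Expects scaler = torch.cuda.amp.GradScaler(); pair with an
    optimizer configured at a small learning rate, e.g. 1e-4.
    """
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16,
                        enabled=torch.cuda.is_available()):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale before clipping raw gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

Gradient checkpointing (tip 3) can be layered on top via `torch.utils.checkpoint` inside the model's forward pass when activation memory is the bottleneck.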

Citation

@InProceedings{Perera_2024_CVPR,
  title={SegFormer3D: An Efficient Transformer for 3D Medical Image Segmentation},
  author={Perera, Shehan and Navard, Pouyan and Yilmaz, Alper},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={4981--4988},
  year={2024}
}

Acknowledgements

This implementation is based on the original SegFormer architecture by Xie et al. and extends it to efficient 3D medical image segmentation. We thank the authors for their valuable contributions to the field.
