---
license: mit
tags:
- medical-imaging
- cardiac-mri
- segmentation
- attention-unet
- pytorch
datasets:
- ACDC
pipeline_tag: image-segmentation
---

# ACDC Heart Segmentation — Attention U-Net Ensemble

A 5-fold cross-validated **Attention U-Net** ensemble trained on the [ACDC cardiac MRI dataset](https://www.creatis.insa-lyon.fr/Challenge/acdc/) for multi-class segmentation of cardiac structures.

## Model Description

This model segments short-axis cardiac MRI slices into 4 classes:

- **Class 0**: Background
- **Class 1**: Right ventricular cavity (RV)
- **Class 2**: Left ventricular myocardium (LVM)
- **Class 3**: Left ventricular cavity (LVC)

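Once `pred` holds per-pixel class indices (as in the Usage section), each structure can be pulled out as its own binary mask, for example to count pixels or compute per-structure metrics. A minimal sketch: `CLASS_NAMES` and `split_masks` are hypothetical helpers, not part of `model.py`, and the toy 2x3 map stands in for real model output.

```python
import torch

# Hypothetical index-to-name mapping, following the class list above.
CLASS_NAMES = {0: "background", 1: "RV", 2: "LVM", 3: "LVC"}


def split_masks(pred: torch.Tensor) -> dict:
    """Turn an argmax map ([H, W] or [B, H, W]) into one boolean mask per foreground class."""
    return {name: pred == idx for idx, name in CLASS_NAMES.items() if idx != 0}


# Toy prediction map instead of real model output.
pred = torch.tensor([[0, 1, 1],
                     [2, 3, 3]])
masks = split_masks(pred)
for name, mask in masks.items():
    print(name, int(mask.sum()))  # prints: RV 2, then LVM 1, then LVC 2
```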
### Architecture

- **Base**: U-Net with Attention Gates
- **Input**: Single-channel grayscale MRI (256x256)
- **Output**: 4-class segmentation map
- **Training**: 5-fold cross-validation on the ACDC training set

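The card names attention gates but does not describe them. For orientation, here is the additive attention gate from the Attention U-Net design in its usual form, assuming the decoder's gating signal `g` has already been resized to the spatial size of the skip features `x`. `AttentionGate`, `g_ch`, `x_ch` and `inter_ch` are illustrative names; the gates in `model.py` may be wired differently.

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Additive attention gate (illustrative sketch, not necessarily what model.py uses).

    g: gating signal from the decoder; x: skip-connection features from the encoder.
    Both are assumed to share the same spatial size.
    """

    def __init__(self, g_ch: int, x_ch: int, inter_ch: int):
        super().__init__()
        self.w_g = nn.Sequential(nn.Conv2d(g_ch, inter_ch, 1), nn.BatchNorm2d(inter_ch))
        self.w_x = nn.Sequential(nn.Conv2d(x_ch, inter_ch, 1), nn.BatchNorm2d(inter_ch))
        # psi maps the combined features to one attention coefficient per pixel in (0, 1).
        self.psi = nn.Sequential(nn.Conv2d(inter_ch, 1, 1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        alpha = self.psi(self.relu(self.w_g(g) + self.w_x(x)))  # [B, 1, H, W]
        return x * alpha  # re-weight skip features before they are concatenated in the decoder


gate = AttentionGate(g_ch=128, x_ch=64, inter_ch=32)
g = torch.randn(2, 128, 64, 64)
x = torch.randn(2, 64, 64, 64)
print(gate(g, x).shape)  # torch.Size([2, 64, 64, 64])
```

Multiplying `x` by a per-pixel coefficient in (0, 1) lets the decoder damp skip-connection features in regions that do not help the current prediction, such as background tissue.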
## Usage

```python
import torch
from model import AttentionUNet  # model.py from this repository

model = AttentionUNet(img_ch=1, output_ch=4)

# A checkpoint may be a bare state_dict or a dict storing it under 'model_state_dict'.
# weights_only=False allows arbitrary pickled objects, so only load trusted files.
checkpoint = torch.load("fold_1_model.pth", map_location="cpu", weights_only=False)
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    checkpoint = checkpoint["model_state_dict"]
model.load_state_dict(checkpoint)
model.eval()

# Input: [batch, 1, 256, 256], normalized as (pixel - 0.5) / 0.5 (mean=0.5, std=0.5).
# torch.randn is only a placeholder for a preprocessed MRI slice.
img_tensor = torch.randn(1, 1, 256, 256)
with torch.no_grad():
    output = model(img_tensor)          # per-class scores: [batch, 4, 256, 256]
    pred = torch.argmax(output, dim=1)  # class index per pixel: [batch, 256, 256]
```

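The usage example feeds random noise; a real slice first has to be brought to the 256x256, `(pixel - 0.5) / 0.5` input format. The card does not say how raw intensities were scaled or how images were resized, so this sketch assumes per-slice min-max scaling to [0, 1] and bilinear resizing. `preprocess_slice` is a hypothetical helper, and its choices may not match the training pipeline.

```python
import torch
import torch.nn.functional as F


def preprocess_slice(slice_2d: torch.Tensor, size: int = 256) -> torch.Tensor:
    """Map a raw 2D MRI slice [H, W] to a model input [1, 1, size, size].

    Assumptions (not confirmed by the card): per-slice min-max scaling to [0, 1],
    bilinear resizing, then the card's normalization (x - 0.5) / 0.5.
    """
    x = slice_2d.float()
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)  # scale intensities to [0, 1]
    x = x[None, None]                               # add batch and channel dims
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return (x - 0.5) / 0.5                          # normalize to roughly [-1, 1]


raw = torch.randint(0, 1200, (216, 256))  # stand-in for one short-axis slice
inp = preprocess_slice(raw)
print(inp.shape)  # torch.Size([1, 1, 256, 256])
```

Predictions come out at 256x256. When mapping a label map back to the original slice size, nearest-neighbour resizing keeps the class indices integral, whereas bilinear resizing would blend labels.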
## Files

| File | Description |
|------|-------------|
| `model.py` | Model architecture (`AttentionUNet`) |
| `fold_1_model.pth` to `fold_5_model.pth` | Trained weights, one checkpoint per CV fold |

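The card describes a 5-fold ensemble but does not say how the fold models are combined. A common choice, assumed here, is to average the per-class softmax probabilities of all folds and take the argmax. `load_fold` and `ensemble_predict` are hypothetical helpers; the `model_state_dict` unwrapping mirrors the usage example.

```python
import torch
import torch.nn as nn


def load_fold(model: nn.Module, path: str) -> nn.Module:
    """Load one fold checkpoint into `model`, unwrapping 'model_state_dict' if present."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        ckpt = ckpt["model_state_dict"]
    model.load_state_dict(ckpt)
    return model.eval()


@torch.no_grad()
def ensemble_predict(models, x: torch.Tensor) -> torch.Tensor:
    """Assumed combination rule: mean of softmax probabilities, then argmax -> [B, H, W].

    Assumes each model returns unnormalized scores of shape [B, C, H, W].
    """
    probs = torch.stack([torch.softmax(m(x), dim=1) for m in models]).mean(dim=0)
    return probs.argmax(dim=1)


# With this repository's files, the fold models could be built as:
#   from model import AttentionUNet
#   models = [load_fold(AttentionUNet(img_ch=1, output_ch=4), f"fold_{k}_model.pth")
#             for k in range(1, 6)]
# Tiny 1x1-conv stand-ins keep this example self-contained.
torch.manual_seed(0)
models = [nn.Conv2d(1, 4, kernel_size=1).eval() for _ in range(5)]
x = torch.randn(2, 1, 256, 256)
pred = ensemble_predict(models, x)
print(pred.shape)  # torch.Size([2, 256, 256])
```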
## Training Details

- **Dataset**: ACDC (Automated Cardiac Diagnosis Challenge)
- **Optimizer**: Adam
- **Loss**: Cross-Entropy + Dice Loss
- **Image Size**: 256x256
- **Normalization**: `(pixel - 0.5) / 0.5`

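The loss is listed as Cross-Entropy + Dice Loss, without weights or smoothing constants. A minimal sketch of one common formulation, assuming equal weighting, a soft Dice averaged over all four classes (background included), and a smoothing constant of 1e-6; `ce_dice_loss` is a hypothetical name, and the exact terms used in training may differ.

```python
import torch
import torch.nn.functional as F


def ce_dice_loss(logits: torch.Tensor, target: torch.Tensor,
                 dice_weight: float = 1.0, eps: float = 1e-6) -> torch.Tensor:
    """Cross-entropy plus soft Dice loss (assumed formulation).

    logits: [B, C, H, W] unnormalized scores; target: [B, H, W] integer class labels.
    """
    ce = F.cross_entropy(logits, target)
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, keeping one Dice score per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)  # soft Dice per class
    # Assumed: equal CE/Dice weighting and a plain mean over all classes.
    return ce + dice_weight * (1 - dice.mean())


logits = torch.randn(2, 4, 256, 256, requires_grad=True)
target = torch.randint(0, 4, (2, 256, 256))
loss = ce_dice_loss(logits, target)
loss.backward()
print(float(loss) > 0)  # True
```

Weighting the Dice term differently or excluding the background channel from the Dice mean are common variants; the card does not say which, if either, was used.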