---
license: mit
tags:
- medical-imaging
- cardiac-mri
- segmentation
- attention-unet
- pytorch
datasets:
- ACDC
pipeline_tag: image-segmentation
---

# ACDC Heart Segmentation — Attention U-Net Ensemble

A 5-fold cross-validated **Attention U-Net** ensemble trained on the [ACDC cardiac MRI dataset](https://www.creatis.insa-lyon.fr/Challenge/acdc/) for multi-class segmentation of cardiac structures.

## Model Description

This model segments cardiac MRI short-axis slices into 4 classes:

- **Class 0**: Background
- **Class 1**: Right Ventricle (RV)
- **Class 2**: Myocardium (LVM)
- **Class 3**: Left Ventricle cavity (LVC)

### Architecture

- **Base**: U-Net with Attention Gates
- **Input**: Single-channel grayscale MRI slice (256x256)
- **Output**: 4-channel per-pixel class scores; taking the argmax over channels gives the segmentation map
- **Training**: 5-fold cross-validation on the ACDC training set

## Usage

```python
import torch
from model import AttentionUNet

model = AttentionUNet(img_ch=1, output_ch=4)

# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust
state_dict = torch.load("fold_1_model.pth", map_location="cpu", weights_only=False)
if "model_state_dict" in state_dict:
    # Checkpoint was saved as a training dict; extract the model weights
    state_dict = state_dict["model_state_dict"]
model.load_state_dict(state_dict)
model.eval()

# Input: [batch, 1, 256, 256], normalized as (pixel - 0.5) / 0.5
# (a random tensor is used here as a placeholder)
img_tensor = torch.randn(1, 1, 256, 256)

with torch.no_grad():
    output = model(img_tensor)          # [batch, 4, 256, 256]
    pred = torch.argmax(output, dim=1)  # [batch, 256, 256], values in {0, 1, 2, 3}
```

## Files

| File | Description |
|------|-------------|
| `model.py` | Model architecture (`AttentionUNet`) |
| `fold_1_model.pth` to `fold_5_model.pth` | Trained weights, one file per CV fold |

## Training Details

- **Dataset**: ACDC (Automated Cardiac Diagnosis Challenge)
- **Optimizer**: Adam
- **Loss**: Cross-Entropy + Dice Loss
- **Image Size**: 256x256
- **Normalization**: (pixel - 0.5) / 0.5
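
The training loss is listed as Cross-Entropy + Dice, but the card does not say which Dice variant was used or how the two terms are weighted. The sketch below is one plausible reading, assuming a soft Dice computed on softmax probabilities, averaged over all four classes (background included), and added to cross-entropy with equal weight. The names `soft_dice_loss`, `ce_dice_loss`, and `dice_weight` are hypothetical and do not come from this repository.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical soft Dice loss, averaged over classes.

    logits: [B, C, H, W] raw class scores
    target: [B, H, W] integer labels in [0, C)
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # One-hot encode labels: [B, H, W] -> [B, H, W, C] -> [B, C, H, W]
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, keeping one value per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()


def ce_dice_loss(logits: torch.Tensor, target: torch.Tensor, dice_weight: float = 1.0) -> torch.Tensor:
    """Cross-entropy plus weighted soft Dice (equal weighting is an assumption)."""
    return F.cross_entropy(logits, target) + dice_weight * soft_dice_loss(logits, target)


# Example with the model's output shape: 4 classes at 256x256
logits = torch.randn(2, 4, 256, 256)
labels = torch.randint(0, 4, (2, 256, 256))
loss = ce_dice_loss(logits, labels)
```

Cross-entropy supplies dense per-pixel gradients. The Dice term is normalized per class, so it helps keep small structures such as the thin myocardial ring from being overwhelmed by the background class.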
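
The Usage snippet loads a single fold, and the card does not say how the five fold models are combined. One common approach is to average per-pixel softmax probabilities across folds and then take the argmax. The sketch below does that. It assumes raw slices are min-max scaled to [0, 1] and bilinearly resized to 256x256 before the documented `(pixel - 0.5) / 0.5` normalization. Both preprocessing steps are assumptions. `preprocess_slice`, `load_fold_models`, and `ensemble_predict` are hypothetical helpers, not part of `model.py`.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_slice(slice_2d: np.ndarray, size: int = 256) -> torch.Tensor:
    """Hypothetical preprocessing: returns a [1, 1, size, size] tensor.

    Min-max scaling and bilinear resizing are assumptions; only the final
    (x - 0.5) / 0.5 step is documented in the model card.
    """
    x = torch.as_tensor(slice_2d, dtype=torch.float32)
    x = (x - x.min()) / (x.max() - x.min() + 1e-8)  # scale intensities to [0, 1]
    x = x[None, None]  # add batch and channel dims -> [1, 1, H, W]
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return (x - 0.5) / 0.5  # documented normalization


def load_fold_models(num_folds: int = 5) -> list:
    """Load fold_1_model.pth ... fold_{num_folds}_model.pth (assumes files sit in the working dir)."""
    from model import AttentionUNet  # model.py from this repository

    models = []
    for k in range(1, num_folds + 1):
        m = AttentionUNet(img_ch=1, output_ch=4)
        ckpt = torch.load(f"fold_{k}_model.pth", map_location="cpu", weights_only=False)
        if "model_state_dict" in ckpt:
            ckpt = ckpt["model_state_dict"]
        m.load_state_dict(ckpt)
        models.append(m.eval())
    return models


@torch.no_grad()
def ensemble_predict(models: list, img: torch.Tensor) -> torch.Tensor:
    """Average softmax probabilities over models; return the [B, H, W] label map."""
    probs = torch.stack([F.softmax(m(img), dim=1) for m in models]).mean(dim=0)
    return probs.argmax(dim=1)
```

A typical call would be `ensemble_predict(load_fold_models(), preprocess_slice(slice_2d))`. Averaging probabilities rather than majority-voting hard labels lets confident folds outweigh uncertain ones at ambiguous boundary pixels. Note that predictions come back at 256x256 and must be resized to the original slice geometry if needed.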