---
license: apache-2.0
---
## Model Overview
* **Model Name:** ImprovedUNet3D
* **Architecture:** 3D U-Net with residual-style encoder-decoder blocks, instance normalization, LeakyReLU activations, and dropout
* **Framework:** PyTorch
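* **Input Channels:** 4 (e.g., the four BraTS MRI sequences: T1, T1ce, T2, FLAIR)
* **Output Channels:** 4 (segmentation classes: background, necrotic/non-enhancing tumor core, peritumoral edema, enhancing tumor)
* **Base Filters:** 16 (configurable through the `base_filters` constructor argument)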
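The card does not include the block definitions. As a hedged illustration of how the listed components (residual connection, instance normalization, LeakyReLU, dropout) typically fit together in one encoder/decoder block, here is a hypothetical `ResidualConvBlock3D`. The class name, kernel sizes, dropout rate, and the doubling of `base_filters` at each level are assumptions, not the actual `ImprovedUNet3D` code.

```python
import torch
import torch.nn as nn


class ResidualConvBlock3D(nn.Module):
    """Hypothetical block: two 3x3x3 convs with InstanceNorm3d, LeakyReLU and
    Dropout3d, plus a residual skip connection (layout is an assumption)."""

    def __init__(self, in_ch: int, out_ch: int, dropout: float = 0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_ch, affine=True),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Dropout3d(p=dropout),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_ch, affine=True),
        )
        # 1x1x1 projection so the skip path matches the output channel count
        self.skip = (
            nn.Identity()
            if in_ch == out_ch
            else nn.Conv3d(in_ch, out_ch, kernel_size=1, bias=False)
        )
        self.act = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual sum, then the final activation
        return self.act(self.body(x) + self.skip(x))


# Assumed width scheme: channels double from base_filters at each encoder level
base_filters = 16
widths = [base_filters * 2**i for i in range(4)]  # [16, 32, 64, 128]

block = ResidualConvBlock3D(in_ch=4, out_ch=widths[0])
x = torch.randn(1, 4, 32, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([1, 16, 32, 32, 32]); padding=1 keeps spatial size
```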
## Intended Use
* **Primary Application:** Brain tumor segmentation on 3D MRI volumes using the BraTS 2020 dataset.
* **Users:** Medical imaging researchers, AI practitioners in healthcare.
* **Out-of-Scope:** Medical diagnosis without expert oversight. Not for real-time intraoperative use.
## Training Data
* **Dataset:** BraTS 2020 (Brain Tumor Segmentation Challenge) training and validation sets
* **Source:** `awsaf49/brats20-dataset-training-validation` on Kaggle
* **Data Volume:** \~369 cases (training + validation)
* **Preprocessing:**
  * Skull stripping
  * Intensity normalization per modality
  * Resampling to uniform voxel size
  * Patching or cropping to fixed volume shape
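Two of these steps, per-modality intensity normalization and cropping to a fixed shape, can be sketched in NumPy as below. This is a minimal sketch under assumptions: z-score normalization over nonzero (brain) voxels is one common choice, not necessarily the one used here, and the 128×128×128 target simply mirrors the example shape in "How to Use"; the card does not state the actual patch size. Function names are hypothetical.

```python
import numpy as np


def zscore_per_modality(volume: np.ndarray) -> np.ndarray:
    """Z-score each modality (channel) using only nonzero (brain) voxels.

    volume: (C, D, H, W) array. Zero voxels are treated as background, which
    assumes skull-stripped input; they stay zero after normalization.
    """
    out = np.zeros_like(volume, dtype=np.float32)
    for c in range(volume.shape[0]):
        channel = volume[c].astype(np.float32)
        mask = channel > 0
        if mask.any():
            values = channel[mask]
            out[c][mask] = (values - values.mean()) / (values.std() + 1e-8)
    return out


def center_crop_or_pad(volume: np.ndarray, target=(128, 128, 128)) -> np.ndarray:
    """Center-crop or zero-pad the spatial axes of a (C, D, H, W) array to `target`."""
    result = volume
    for axis, size in enumerate(target, start=1):  # axis 0 is the channel axis
        current = result.shape[axis]
        if current > size:  # crop symmetrically
            start = (current - size) // 2
            slicer = [slice(None)] * result.ndim
            slicer[axis] = slice(start, start + size)
            result = result[tuple(slicer)]
        elif current < size:  # zero-pad symmetrically
            before = (size - current) // 2
            pad_width = [(0, 0)] * result.ndim
            pad_width[axis] = (before, size - current - before)
            result = np.pad(result, pad_width)
    return result


# Toy 4-modality volume: nonzero "brain" block inside a zero background
rng = np.random.default_rng(0)
raw = np.zeros((4, 160, 160, 100), dtype=np.float32)
raw[:, 40:120, 40:120, 20:80] = rng.uniform(100.0, 1000.0, size=(4, 80, 80, 60))

normalized = zscore_per_modality(raw)
patch = center_crop_or_pad(normalized)  # 160 -> 128 (crop), 100 -> 128 (pad)
print(patch.shape)  # (4, 128, 128, 128)
```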
## Performance
| Metric | Necrotic / Non-Enhancing Tumor Core | Peritumoral Edema | Enhancing Tumor | Background |
| ---------------- | ----------------------------------- | ----------------- | --------------- | ---------- |
| Dice Coefficient | 0.6448 | 0.7727 | 0.8026 | 0.9989 |
| Hausdorff95 (mm) | 7.6740 | 8.4238 | 5.0973 | 0.2464 |
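For reference, the Dice coefficient compares a predicted mask P with a ground-truth mask T as `Dice = 2|P ∩ T| / (|P| + |T|)`, ranging from 0 (no overlap) to 1 (perfect overlap). The sketch below computes it per class on integer label maps with NumPy. The function name and the convention of scoring a class absent from both maps as 1.0 are assumptions, and this is not necessarily how the numbers above were computed.

```python
import numpy as np


def dice_per_class(pred: np.ndarray, target: np.ndarray, num_classes: int = 4) -> list:
    """Dice = 2|P ∩ T| / (|P| + |T|) for each class in integer label maps."""
    scores = []
    for c in range(num_classes):
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        # Assumed convention: a class absent from both maps scores 1.0
        scores.append(1.0 if denom == 0 else 2.0 * np.logical_and(p, t).sum() / denom)
    return scores


# Toy label maps (any shape works; 1D keeps the arithmetic easy to follow)
target = np.array([0, 0, 1, 1, 2, 2, 3, 3])
pred = np.array([0, 0, 1, 2, 2, 2, 3, 0])
scores = dice_per_class(pred, target)
print([round(float(s), 4) for s in scores])  # [0.8, 0.6667, 0.8, 0.6667]
```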
## Limitations and Risks
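* **Domain Shift:** The model may not generalize to scanners, acquisition protocols, or institutions outside the BraTS data.
* **Class Imbalance:** Smaller, rarer tumor subregions may be segmented less accurately; the necrotic/non-enhancing tumor core has the lowest Dice score (0.6448) of all classes.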
* **Clinical Use:** Intended for research only; does not replace expert radiologist interpretation.
## How to Use
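```python
import torch

from improved_unet3d import ImprovedUNet3D

# Instantiate the model (4 MRI modalities in, 4 segmentation classes out)
model = ImprovedUNet3D(in_channels=4, out_channels=4, base_filters=16)

# Load pretrained weights (if available); map_location="cpu" lets a
# GPU-saved checkpoint load on a CPU-only machine
state_dict = torch.load("path/to/checkpoint.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # disables dropout for inference

# Inference on a single 3D volume: (batch, channels, depth, height, width)
input_volume = torch.randn(1, 4, 128, 128, 128)  # example shape
with torch.no_grad():
    output = model(input_volume)  # per-class scores, shape [1, 4, 128, 128, 128]

# Per-voxel predicted class index, shape [1, 128, 128, 128]
prediction = output.argmax(dim=1)
```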
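To compare predictions with BraTS ground truth, which uses labels 0 (background), 1 (necrotic/non-enhancing tumor core), 2 (peritumoral edema) and 4 (enhancing tumor), and with the nested regions used in the official BraTS evaluation (whole tumor, tumor core, enhancing tumor), the predicted class indices need remapping. The sketch below assumes output channels 0 to 3 correspond to background, necrotic/non-enhancing core, edema, and enhancing tumor in that order; the card does not state the channel order, so check it against the label encoding used in training. Function names are hypothetical.

```python
import torch


def to_brats_labels(indices: torch.Tensor) -> torch.Tensor:
    """Map contiguous indices {0, 1, 2, 3} to BraTS labels {0, 1, 2, 4}.

    Assumes index 3 is enhancing tumor (hypothetical channel order).
    """
    brats = indices.clone()
    brats[indices == 3] = 4
    return brats


def brats_regions(labels: torch.Tensor) -> dict:
    """Nested BraTS evaluation regions from a BraTS-labelled map."""
    return {
        "whole_tumor": labels > 0,                    # labels 1, 2, 4
        "tumor_core": (labels == 1) | (labels == 4),  # labels 1, 4
        "enhancing_tumor": labels == 4,               # label 4
    }


# Toy 1x2x2x2 map of class indices (stand-in for output.argmax(dim=1))
pred = torch.tensor([[[[0, 1], [2, 3]], [[3, 3], [1, 0]]]])
labels = to_brats_labels(pred)
counts = {name: int(mask.sum()) for name, mask in brats_regions(labels).items()}
print(counts)  # {'whole_tumor': 6, 'tumor_core': 5, 'enhancing_tumor': 3}
```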
## Training Details
* **Optimizer:** Adam
* **Learning Rate:** 1e-4
* **Batch Size:** 2
* **Loss Function:** Combined Dice + Cross-Entropy
* **Epochs:** 200
* **Scheduler:** Cosine annealing or Step LR
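The settings above can be wired together roughly as follows. This is a sketch under stated assumptions: the card does not define its combined loss, so `DiceCELoss` below is a hypothetical equal-weight sum of PyTorch's cross-entropy and a soft Dice term; cosine annealing is picked from the two listed schedulers; and a single `nn.Conv3d` stands in for `ImprovedUNet3D` so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceCELoss(nn.Module):
    """Hypothetical combined loss: cross-entropy + (1 - mean soft Dice), equal weights."""

    def __init__(self, smooth: float = 1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (N, C, D, H, W); target: (N, D, H, W) int64 class indices
        ce = F.cross_entropy(logits, target)
        probs = logits.softmax(dim=1)
        one_hot = F.one_hot(target, num_classes=logits.shape[1])  # (N, D, H, W, C)
        one_hot = one_hot.permute(0, 4, 1, 2, 3).float()            # (N, C, D, H, W)
        dims = (0, 2, 3, 4)  # sum over batch and spatial axes, keep classes
        intersection = (probs * one_hot).sum(dims)
        denom = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (denom + self.smooth)
        return ce + (1 - dice.mean())


# Stand-in network so the sketch runs; replace with ImprovedUNet3D
model = nn.Conv3d(4, 4, kernel_size=3, padding=1)
criterion = DiceCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 200
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# One illustrative step on random data (batch size 2, as listed above)
images = torch.randn(2, 4, 32, 32, 32)
targets = torch.randint(0, 4, (2, 32, 32, 32))
optimizer.zero_grad()
loss = criterion(model(images), targets)
loss.backward()
optimizer.step()
scheduler.step()  # in a full loop, call once per epoch after the batches
```

Stepping `CosineAnnealingLR` once per epoch with `T_max=epochs` decays the learning rate from 1e-4 toward zero over the 200 epochs; a `StepLR` would instead drop it by a fixed factor at set intervals.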
## Ethical Considerations
* **Bias:** Trained on a specific dataset; demographic coverage may be limited.
* **Privacy:** Data must be anonymized. Users should ensure HIPAA/GDPR compliance.
## Citation
If you use this model, please cite: