---
license: apache-2.0
---

## Model Overview

* **Model Name:** ImprovedUNet3D
* **Architecture:** 3D U-Net with residual-style encoder-decoder blocks, instance normalization, LeakyReLU activations, and dropout
* **Framework:** PyTorch
* **Input Channels:** 4 (one per MRI modality; in BraTS 2020 these are T1, T1ce, T2, and FLAIR)
* **Output Channels:** 4 (segmentation classes: background, necrotic/non-enhancing tumor core, peritumoral edema, and enhancing tumor)
* **Base Filters:** 16 (set via the `base_filters` constructor argument)
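
The overview names the building blocks but not how they fit together. As an illustration only, a residual encoder block combining instance normalization, LeakyReLU, and dropout might look like the sketch below. The class name `ResidualConvBlock3D`, the 3×3×3 kernels, the 1×1×1 skip projection, the dropout rate of 0.1, and the LeakyReLU slope of 0.01 are assumptions, not the actual `ImprovedUNet3D` internals.

```python
import torch
import torch.nn as nn


class ResidualConvBlock3D(nn.Module):
    """Hypothetical block: two 3x3x3 convs with InstanceNorm, LeakyReLU, dropout, and a skip path."""

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_channels, affine=True),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Dropout3d(p=dropout),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm3d(out_channels, affine=True),
        )
        # 1x1x1 projection so the skip path matches the output channel count
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)
        )
        self.act = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual sum, then the final nonlinearity
        return self.act(self.body(x) + self.skip(x))
```

In a U-Net assembled from such blocks, the encoder width commonly doubles at each level starting from `base_filters` (16, 32, 64, ...), with the decoder mirroring it.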

## Intended Use

* **Primary Application:** Brain tumor segmentation on 3D MRI volumes using the BraTS 2020 dataset.
* **Users:** Medical imaging researchers, AI practitioners in healthcare.
* **Out-of-Scope:** Medical diagnosis without expert oversight. Not for real-time intraoperative use.

## Training Data

* **Dataset:** BraTS 2020 (Multimodal Brain Tumor Segmentation Challenge) training and validation sets
* **Source:** `awsaf49/brats20-dataset-training-validation` on Kaggle
* **Data Volume:** ~369 cases (training + validation)
* **Preprocessing:**
  * Skull stripping
  * Intensity normalization per modality
  * Resampling to a uniform voxel size
  * Patching or cropping to a fixed volume shape
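
The preprocessing steps are listed without details. The sketch below shows one common way to implement the last two, assuming each case is a NumPy array of shape `(C, D, H, W)` with one channel per modality: z-score normalization over nonzero (brain) voxels, and a center crop or zero pad to a fixed shape. The function names, the zero-valued-background convention, and the 128³ default (taken from the usage example's input shape) are assumptions, not this model's actual pipeline.

```python
import numpy as np


def normalize_per_modality(volume: np.ndarray) -> np.ndarray:
    """Z-score each modality over nonzero (brain) voxels; background stays 0.

    volume: (C, D, H, W) array, one channel per MRI modality.
    """
    out = np.zeros_like(volume, dtype=np.float32)
    for c in range(volume.shape[0]):
        channel = volume[c].astype(np.float32)
        brain = channel != 0  # assumes skull-stripped scans with zero background
        if brain.any():
            mean = channel[brain].mean()
            std = channel[brain].std()
            out[c][brain] = (channel[brain] - mean) / max(std, 1e-8)
    return out


def center_crop_or_pad(volume: np.ndarray, target=(128, 128, 128)) -> np.ndarray:
    """Center-crop or zero-pad the spatial axes of a (C, D, H, W) array to `target`."""
    result = volume
    for axis, size in enumerate(target, start=1):
        current = result.shape[axis]
        if current > size:
            start = (current - size) // 2
            result = np.take(result, np.arange(start, start + size), axis=axis)
        elif current < size:
            before = (size - current) // 2
            pad = [(0, 0)] * result.ndim
            pad[axis] = (before, size - current - before)
            result = np.pad(result, pad)
    return result
```

Normalizing only over brain voxels keeps the large zero background from skewing each modality's mean and standard deviation.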

## Performance

| Metric           | Necrotic/Non-Enhancing Tumor Core (NCR/NET) | Peritumoral Edema | Enhancing Tumor | Background |
| ---------------- | ------------------------------------------- | ----------------- | --------------- | ---------- |
| Dice Coefficient | 0.6448                                      | 0.7727            | 0.8026          | 0.9989     |
| Hausdorff95 (mm) | 7.6740                                      | 8.4238            | 5.0973          | 0.2464     |
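
Per-class scores like these are usually computed on binary masks (one class versus the rest). The NumPy sketch below assumes `(D, H, W)` masks and voxel spacing in mm; it computes Dice and an HD95 that pools surface-to-surface distances in both directions before taking the 95th percentile, similar to MedPy's `hd95`. It is not necessarily how the table was produced. The empty-mask conventions and the function names are assumptions, and the brute-force pairwise distance matrix only suits small masks (full volumes call for a distance-transform approach).

```python
import numpy as np


def dice(pred: np.ndarray, target: np.ndarray) -> float:
    """Dice = 2|A & B| / (|A| + |B|); returns 1.0 when both masks are empty (a convention)."""
    pred, target = pred.astype(bool), target.astype(bool)
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, target).sum() / total)


def _surface(mask: np.ndarray) -> np.ndarray:
    """Mask voxels with at least one 6-connected neighbour outside the mask."""
    padded = np.pad(mask, 1)  # False border so voxels on the array edge count as surface
    inner = tuple(slice(1, -1) for _ in range(mask.ndim))
    interior = mask.copy()
    for axis in range(mask.ndim):
        for shift in (-1, 1):
            interior &= np.roll(padded, shift, axis=axis)[inner]
    return mask & ~interior


def hd95(pred: np.ndarray, target: np.ndarray, spacing=(1.0, 1.0, 1.0)) -> float:
    """95th-percentile symmetric surface distance in mm (brute force, small masks only)."""
    pred, target = pred.astype(bool), target.astype(bool)
    if not pred.any() and not target.any():
        return 0.0
    if not pred.any() or not target.any():
        return float("inf")
    sp = np.asarray(spacing, dtype=float)
    a = np.argwhere(_surface(pred)) * sp  # surface voxel coordinates in mm
    b = np.argwhere(_surface(target)) * sp
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (|A|, |B|)
    # Pool both directed nearest-surface distances, then take the 95th percentile
    pooled = np.concatenate([dists.min(axis=1), dists.min(axis=0)])
    return float(np.percentile(pooled, 95))
```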

## Limitations and Risks

* **Overfitting:** Model may not generalize to scanners or protocols outside BraTS.
* **Data Imbalance:** Rare tumor subregions may have lower performance.
* **Clinical Use:** Intended for research only; does not replace expert radiologist interpretation.

## How to Use

```python
import torch
from improved_unet3d import ImprovedUNet3D

# Instantiate the model with the settings listed in the Model Overview
model = ImprovedUNet3D(in_channels=4, out_channels=4, base_filters=16)

# Load pretrained weights (if available); map_location lets a GPU-saved checkpoint load on CPU
state_dict = torch.load("path/to/checkpoint.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Inference on a single 3D volume shaped (batch, modalities, depth, height, width)
input_volume = torch.randn(1, 4, 128, 128, 128)  # example shape
with torch.no_grad():
    output = model(input_volume)  # per-class scores, shape: [1, 4, 128, 128, 128]

# Predicted class index for each voxel
segmentation = output.argmax(dim=1)  # shape: [1, 128, 128, 128]
```

## Training Details

* **Optimizer:** Adam
* **Learning Rate:** 1e-4
* **Batch Size:** 2
* **Loss Function:** Combined Dice + Cross-Entropy
* **Epochs:** 200
* **Scheduler:** Cosine annealing or Step LR
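
The loss and optimizer settings above could be wired up in PyTorch roughly as follows. This is a sketch, not the training code used for this model. The following are all assumptions: the equal Dice and cross-entropy weights, the soft-Dice smoothing constant, the names `DiceCELoss` and `build_optimizer`, and the choice of `CosineAnnealingLR` (rather than step decay) over the 200 epochs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceCELoss(nn.Module):
    """Hypothetical combined loss: soft multi-class Dice plus cross-entropy on raw logits."""

    def __init__(self, dice_weight: float = 1.0, ce_weight: float = 1.0, smooth: float = 1e-5):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (N, C, D, H, W); target: (N, D, H, W) int64 class indices
        ce = F.cross_entropy(logits, target)

        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes=logits.shape[1])  # (N, D, H, W, C)
        one_hot = one_hot.permute(0, 4, 1, 2, 3).to(probs.dtype)  # (N, C, D, H, W)

        dims = (0, 2, 3, 4)  # sum over batch and space, keeping one Dice term per class
        intersection = (probs * one_hot).sum(dims)
        denominator = probs.sum(dims) + one_hot.sum(dims)
        dice_per_class = (2 * intersection + self.smooth) / (denominator + self.smooth)
        dice_loss = 1 - dice_per_class.mean()

        return self.dice_weight * dice_loss + self.ce_weight * ce


def build_optimizer(model: nn.Module, lr: float = 1e-4, epochs: int = 200):
    """Adam at the listed learning rate, with cosine annealing stepped once per epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    return optimizer, scheduler
```

Call `scheduler.step()` once at the end of each epoch. The Dice term counters class imbalance between the small tumor subregions and the large background, while cross-entropy gives stable per-voxel gradients.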

## Ethical Considerations

* **Bias:** Trained on a specific dataset; demographic coverage may be limited.
* **Privacy:** Data must be anonymized. Users should ensure HIPAA/GDPR compliance.

## Citation

If you use this model, please cite: