# 🧠 Brain Tumor Segmentation — BRISC2025
> Binary segmentation of brain tumors from MRI scans using DeepLabV3+ (MobileNetV2) and EfficientNet-B1 U-Net with two-stage transfer learning.
[![TensorFlow](https://img.shields.io/badge/TensorFlow-2.x-FF6F00?logo=tensorflow)](https://tensorflow.org)
[![Python](https://img.shields.io/badge/Python-3.11-3776AB?logo=python)](https://python.org)
[![License](https://img.shields.io/badge/License-MIT-green)](LICENSE)
[![Dataset](https://img.shields.io/badge/Dataset-BRISC2025-blue)](https://brisc2025.grand-challenge.org)
---
## 📋 Table of Contents
- [Overview](#overview)
- [Dataset](#dataset)
- [Models](#models)
- [Results](#results)
- [Project Structure](#project-structure)
- [Installation](#installation)
- [Usage](#usage)
- [Training](#training)
- [Evaluation](#evaluation)
- [Requirements](#requirements)
- [Citation](#citation)
- [License](#license)
- [Acknowledgements](#acknowledgements)
---
## Overview
This project tackles **binary brain tumor segmentation** on MRI scans from the BRISC2025 challenge. Given an MRI image, the model outputs a binary mask where:
- `0` = background (healthy tissue)
- `1` = tumor region
Two architectures are implemented and compared:
1. **DeepLabV3+** with MobileNetV2 backbone — lightweight and fast
2. **EfficientNet-B1 U-Net** with ASPP bridge — higher accuracy
Both models are trained using a **two-stage transfer learning** strategy: frozen backbone warm-up followed by full fine-tuning.
---
## Dataset
### BRISC2025 — Brain Image Segmentation Challenge 2025
| Split | Images | Masks |
|-------|--------|-------|
| Train | 3,933 | 3,933 |
| Test | 860 | 860 |
**File naming convention:**
```
images/brisc2025_train_00001_gl_ax_t1.jpg
masks/ brisc2025_train_00001_gl_ax_t1.png ← same stem, different extension
```
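Images and masks share a filename stem, so the stem alone is enough to pair them. Below is a minimal sketch using `pathlib`. The helper name `pair_images_and_masks` is hypothetical, and the sketch assumes the `.jpg`/`.png` layout shown above:

```python
from pathlib import Path

def pair_images_and_masks(image_dir, mask_dir):
    """Match each .jpg scan to the .png mask with the same filename stem."""
    image_dir, mask_dir = Path(image_dir), Path(mask_dir)
    masks_by_stem = {p.stem: p for p in mask_dir.glob('*.png')}
    pairs, missing = [], []
    for img_path in sorted(image_dir.glob('*.jpg')):
        mask_path = masks_by_stem.get(img_path.stem)
        if mask_path is None:
            missing.append(img_path.name)  # scan without a mask: report it, don't guess
        else:
            pairs.append((img_path, mask_path))
    return pairs, missing

pairs, missing = pair_images_and_masks('./segmentation_task/train/images',
                                       './segmentation_task/train/masks')
print(f'{len(pairs)} pairs, {len(missing)} images without masks')
```

The function reports unmatched scans instead of skipping them silently, which makes a broken download easy to spot.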
**Directory structure:**
```
brisc2025/
├── classification_task/
│   ├── train/   {glioma/, meningioma/, no_tumor/, pituitary/}
│   └── test/    {glioma/, meningioma/, no_tumor/, pituitary/}
└── segmentation_task/
    ├── train/
    │   ├── images/   ← MRI scans (.jpg)
    │   └── masks/    ← binary tumor masks (.png)
    └── test/
        ├── images/
        └── masks/
```
**Key characteristics:**
- Modality: MRI (axial T1-weighted)
- Classes: Background (0) / Tumor (1) — binary segmentation
- Class imbalance: ~95% background pixels, which motivates an imbalance-aware loss (Tversky Loss)
- Input resolution: resized to `128×128` for training
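You can check the background share on your own copy by computing it directly from the masks. Below is a minimal NumPy sketch. The helper `background_fraction` is hypothetical, and it assumes the masks are already loaded into one array, for example with Pillow and a nearest-neighbour resize to `128×128`:

```python
import numpy as np

def background_fraction(masks):
    """Fraction of background (0) pixels across a stack of masks.

    `masks` has shape (N, H, W); tumor pixels are any value > 0, so both
    0/1 and 0/255 PNG masks work.
    """
    masks = np.asarray(masks)
    return float(np.mean(masks == 0))
```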
> ⬇️ Download the dataset from the [BRISC2025 challenge page](https://brisc2025.grand-challenge.org).
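> ⬇️ Preprocessed, GZIP-compressed TFRecord files are provided in the `tfrecord/` directory of this repository (see [Project Structure](#project-structure)).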
---
## Models
### 1. DeepLabV3+ — MobileNetV2 Backbone
```
Input (128×128×3)
└── MobileNetV2 (ImageNet pretrained)
    ├── Low-level features  [block_2_add]          → 32×32 (stride 4)
    └── High-level features [block_13_expand_relu] → 8×8 (stride 16)
        └── ASPP (1×1 conv + atrous rates 6, 12, 18 + global avg pool)
            └── Decoder (skip connection + upsampling)
                └── Output: softmax (128×128×2)
```
### 2. EfficientNet-B1 U-Net
```
Input (128×128×3)
└── EfficientNetB1 (ImageNet pretrained)
    ├── s1: block2a_expand_activation
    ├── s2: block3a_expand_activation
    ├── s3: block4a_expand_activation
    ├── s4: block6a_expand_activation
    └── s5: top_activation
        └── ASPP Bridge
            └── Decoder (4× skip connections)
                └── Output: softmax (128×128×2)
```
### Training Strategy — 2 Stages
| Stage | Epochs | Backbone | Learning Rate |
|-------|--------|----------|---------------|
| Stage 1 — Warm-up | 20 | Frozen | Cosine Decay (1e-4 → 1e-6) |
| Stage 2 — Fine-tune | 80 | Trainable | Cosine Decay Restarts |
**Loss function — Tversky Loss** (suited to class-imbalanced segmentation):
```
Tversky index  TI   = TP / (TP + α·FN + β·FP)      α = 0.7, β = 0.3
Tversky loss   L_TI = 1 − TI
```
Because α > β, false negatives (missed tumor pixels) are penalized more heavily than false positives.
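The formula above can be applied directly to soft predictions. Below is a NumPy reference sketch. It assumes `y_prob` is the tumor channel of the softmax output and adds a small smoothing term `eps`, which is an assumption and not necessarily what `train.py` uses. A training version would compute the same sums with TensorFlow ops:

```python
import numpy as np

def tversky_loss(y_true, y_prob, alpha=0.7, beta=0.3, eps=1e-6):
    """1 - Tversky index on soft predictions.

    y_true: binary ground-truth mask (0/1), any shape
    y_prob: predicted tumor probability with the same shape
    alpha weights false negatives, beta weights false positives.
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_prob = np.asarray(y_prob, dtype=np.float64).ravel()
    tp = np.sum(y_true * y_prob)          # soft true positives
    fn = np.sum(y_true * (1.0 - y_prob))  # soft false negatives (missed tumor)
    fp = np.sum((1.0 - y_true) * y_prob)  # soft false positives
    tversky_index = (tp + eps) / (tp + alpha * fn + beta * fp + eps)
    return 1.0 - tversky_index
```

With α = 0.7 and β = 0.3, a prediction that misses one tumor pixel receives a higher loss than one that adds a single extra pixel.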
---
## Results
Results shown for **EfficientNet-B1 U-Net** (best performing model):
### Training Curves (Stage 2 — Fine-tuning)
![history_stage2_eff_net](https://cdn-uploads.huggingface.co/production/uploads/66122f4755f5219b33b5ba23/tOJJbDIMD_wwQ6dlKnq3l.png)
### Evaluation Curves
![evaluation_curves](https://cdn-uploads.huggingface.co/production/uploads/66122f4755f5219b33b5ba23/fUyb7Fx5bUX3ndE9x0NJz.png)
### Metrics on Test Set
#### Group 1 — Pixel-level Metrics
| Metric | Value |
|--------|-------|
| Accuracy | 0.9633 |
| Precision | 0.9910 |
| Recall (Sensitivity) | 0.9660 |
| Specificity | 0.9465 |
| F1-Score | 0.9784 |
#### Group 2 — Segmentation Metrics
| Metric | Value |
|--------|-------|
| **Dice Coefficient** | **0.9783 ± 0.0211** |
| **IoU / Jaccard** | **0.9582 ± 0.0359** |
| Mean IoU | 0.9582 |
| Hausdorff Distance (HD) | 5.10 px |
| HD95 | 0.15 px |
#### Group 3 — Medical Metrics
| Metric | Value |
|--------|-------|
| Volumetric Similarity (VS) | 0.9834 |
| Avg Surface Distance (ASD) | 0.05 px |
| Sensitivity @ Specificity ≥ 0.95 | 0.9639 |
#### Group 4 — AUC Metrics
| Metric | Value |
|--------|-------|
| AUC-ROC | 0.9681 |
| AUC-PR | 0.9914 |
---
## Project Structure
```
brain-tumor-segmentation/
├── train.py # Main training script
├── brain_tumor_segmentation.ipynb # Full pipeline notebook
├── load_data.ipynb # Data loading & exploration
├── model/
│ ├── deeplab_mobilenet.h5 # Trained DeepLabV3+ weights
│ └── eff_net.h5 # Trained EfficientNet-UNet weights
├── tfrecord/
│ ├── train_*.tfrecord # Training data (GZIP compressed)
│ ├── val_*.tfrecord # Validation data
│ └── test_*.tfrecord # Test data
├── assets/
│ ├── history_stage2_eff_net.png
│ └── evaluation_curves.png
└── README.md
```
---
## Installation
```bash
# Clone repository
git clone https://github.com/your-username/brain-tumor-segmentation.git
cd brain-tumor-segmentation
# Create conda environment
conda create -n seg python=3.11 -y
conda activate seg
# Install dependencies
pip install "tensorflow>=2.13"
pip install scikit-learn scipy matplotlib opencv-python Pillow
pip install numpy pandas
```
**GPU setup (optional but recommended):**
```bash
pip install "tensorflow[and-cuda]"   # Linux only; installs CUDA/cuDNN wheels (TensorFlow >= 2.14)
```
---
## Usage
### Quick Inference — Load model and predict
```python
import tensorflow as tf
import numpy as np
from PIL import Image
# Load model (no need to declare custom objects)
model = tf.keras.models.load_model('./model/eff_net.h5', compile=False)
# Prepare input image
img = Image.open('path/to/mri.jpg').convert('RGB').resize((128, 128))  # force 3 channels (grayscale JPEGs load as 1 channel)
img_array = np.array(img, dtype=np.float32) / 255.0
img_array = np.expand_dims(img_array, axis=0) # (1, 128, 128, 3)
# Predict
pred = model.predict(img_array) # (1, 128, 128, 2)
mask = np.argmax(pred[0], axis=-1) # (128, 128) — 0 or 1
print(f'Tumor pixels: {mask.sum()} / {mask.size}')
```
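The predicted mask is `128×128`, so it must be resized before you overlay it on the original scan. Below is a minimal sketch that uses nearest-neighbour indexing, which keeps the labels strictly 0/1. Bilinear interpolation would instead blur the boundary into fractional values. The helper name is hypothetical. Pillow's `Image.fromarray(...).resize(..., Image.NEAREST)` is an equivalent alternative:

```python
import numpy as np

def upsample_mask_nearest(mask, out_h, out_w):
    """Nearest-neighbour resize of an (H, W) label mask to (out_h, out_w)."""
    mask = np.asarray(mask)
    in_h, in_w = mask.shape
    # Map every output row/column back to the source row/column it falls in
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return mask[rows[:, None], cols[None, :]]
```

For example, `upsample_mask_nearest(mask, *Image.open('path/to/mri.jpg').size[::-1])` restores the scan's original height and width. The `[::-1]` is needed because Pillow reports size as `(width, height)`.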
### Batch Inference on Test Set
```python
import tensorflow as tf
import numpy as np
from pathlib import Path

def load_image(path, size=(128, 128)):
    """Read a JPEG, force 3 channels, resize, and scale to [0, 1]."""
    raw = tf.io.read_file(path)
    image = tf.image.decode_jpeg(raw, channels=3)
    image = tf.image.resize(image, size)
    return tf.cast(image, tf.float32) / 255.0

# Load model
model = tf.keras.models.load_model('./model/eff_net.h5', compile=False)

# Get test image paths (sorted, so predictions line up with filenames)
test_imgs = sorted(str(p) for p in Path('./segmentation_task/test/images').glob('*.jpg'))

# Batch with tf.data instead of calling predict() once per image
ds = (tf.data.Dataset.from_tensor_slices(test_imgs)
      .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(8)
      .prefetch(tf.data.AUTOTUNE))

probs = model.predict(ds, verbose=0)      # (N, 128, 128, 2)
predictions = np.argmax(probs, axis=-1)   # (N, 128, 128), values 0 or 1
print(f'Predicted {len(predictions)} masks')
```
### Visualize Predictions
```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Uses load_image() from the batch inference example above
def visualize(image_path, mask_path, model, size=(128, 128)):
    # Load image and ground truth (nearest-neighbour resize keeps the mask binary)
    img = load_image(image_path, size).numpy()
    true = tf.image.decode_png(tf.io.read_file(mask_path), channels=1)
    true = tf.image.resize(true, size, method='nearest').numpy().squeeze()
    # Predict
    pred_prob = model.predict(np.expand_dims(img, 0), verbose=0)
    pred_mask = np.argmax(pred_prob[0], axis=-1)
    # Overlay: red = prediction only, green = ground truth only, yellow = overlap
    overlay = img.copy()
    overlay[pred_mask == 1] = [1.0, 0.2, 0.2]
    overlay[true > 0] = [0.2, 1.0, 0.2]
    both = (pred_mask == 1) & (true > 0)
    overlay[both] = [1.0, 1.0, 0.0]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    titles = ['MRI Input', 'Ground Truth', 'Prediction', 'Overlay']
    imgs = [img, true, pred_mask, overlay]
    cmaps = [None, 'gray', 'gray', None]
    for ax, title, im, cmap in zip(axes, titles, imgs, cmaps):
        ax.imshow(im, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Example (a training-set file that follows the naming convention above)
model = tf.keras.models.load_model('./model/eff_net.h5', compile=False)
visualize(
    './segmentation_task/train/images/brisc2025_train_00001_gl_ax_t1.jpg',
    './segmentation_task/train/masks/brisc2025_train_00001_gl_ax_t1.png',
    model,
)
```
---
## Training
### Step 1 — Prepare TFRecord files
The training pipeline uses TFRecord format (GZIP compressed) for efficient data loading:
```python
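import tensorflow as tf
from pathlib import Path

def create_tfrecord(image_paths, mask_paths, output_path):
    """Write (image, mask) pairs to a GZIP-compressed TFRecord as raw encoded bytes."""
    writer_options = tf.io.TFRecordOptions(compression_type='GZIP')
    with tf.io.TFRecordWriter(str(output_path), options=writer_options) as writer:
        for img_path, mask_path in zip(image_paths, mask_paths):
            # Keep the original JPEG/PNG encoding; decoding happens in the input pipeline
            img_bytes = tf.io.read_file(str(img_path)).numpy()
            mask_bytes = tf.io.read_file(str(mask_path)).numpy()
            feature = {
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_bytes])),
                'mask' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask_bytes])),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# Example: pair each image with the same-stem mask (the shard name is hypothetical)
root = Path('./segmentation_task/train')
image_paths = sorted((root / 'images').glob('*.jpg'))
mask_paths = [root / 'masks' / f'{p.stem}.png' for p in image_paths]
create_tfrecord(image_paths, mask_paths, './tfrecord/train_000.tfrecord')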
```
### Step 2 — Run training
```bash
python train.py
```
**Training configuration:**
```python
BATCH_SIZE = 8
EPOCHS_STAGE_1 = 20 # Frozen backbone warm-up
EPOCHS_STAGE_2 = 80 # Full fine-tuning
TARGET_SIZE = (128, 128)
```
**Optimizer — Adam with Cosine Decay Restarts:**
```python
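import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=1e-4,
    first_decay_steps=1000,  # length of the first cosine cycle, in optimizer steps
    t_mul=2.0,               # each cycle is twice as long as the previous one
    m_mul=0.9,               # each restart peaks at 90% of the previous peak
    alpha=1e-6,              # floor as a FRACTION of initial_learning_rate (1e-4 * 1e-6 = 1e-10)
)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule,
    clipnorm=1.0,            # clip each weight's gradient to L2 norm <= 1.0
)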
```
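The sketch below re-implements the warm-restart schedule in plain Python to show what these settings do in practice. It mirrors the SGDR logic of Keras's `CosineDecayRestarts` based on our reading of its source. Treat it as an illustration only. The Keras object, which you can call as `lr_schedule(step)`, remains the source of truth:

```python
import math

def cosine_decay_restarts(step, initial_lr=1e-4, first_decay_steps=1000,
                          t_mul=2.0, m_mul=0.9, alpha=1e-6):
    """Learning rate at `step` under cosine decay with warm restarts (SGDR)."""
    completed = step / first_decay_steps
    if t_mul == 1.0:
        i_restart = math.floor(completed)
        completed -= i_restart
    else:
        # Cycle i lasts first_decay_steps * t_mul**i steps; find the current cycle
        i_restart = math.floor(math.log(1.0 - completed * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        completed = (completed - sum_r) / t_mul ** i_restart
    m_fac = m_mul ** i_restart  # the peak shrinks by m_mul after each restart
    cosine = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed))
    # alpha is a fraction of initial_lr, not an absolute learning rate
    return initial_lr * ((1.0 - alpha) * cosine + alpha)

for s in (0, 500, 999, 1000, 2000):
    print(s, f'{cosine_decay_restarts(s):.3e}')
```

With `alpha=1e-6`, the learning rate drops to roughly `1e-10` at the end of each cycle. A floor of `1e-6` would instead correspond to `alpha=1e-2`.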
**Augmentation strategy:**
- 50% random crop (50–100% of original size)
- 50% focused crop (centered on rare/tumor region)
- Random horizontal flip
- Random brightness (±15%), contrast (±15%), saturation (±15%)
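The geometric augmentations above must move the image and the mask together. Below is a minimal NumPy sketch of that idea. It is not the project's actual pipeline, and several parts are assumptions: the helper name, reading "size" as side length, the multiplicative brightness jitter, the nearest-neighbour resize of the image (bilinear is more usual for images), and the omission of saturation jitter:

```python
import numpy as np

def augment_pair(image, mask, rng, min_scale=0.5):
    """Jointly augment an (H, W, 3) float image in [0, 1] and an (H, W) 0/1 mask.

    Crop and flip are applied identically to both arrays;
    brightness and contrast jitter touch the image only.
    """
    h, w = mask.shape
    scale = rng.uniform(min_scale, 1.0)  # crop side = 50-100% of the original side
    ch, cw = max(1, int(h * scale)), max(1, int(w * scale))
    tumor = np.argwhere(mask > 0)
    if len(tumor) > 0 and rng.random() < 0.5:
        # Focused crop: center the window on a random tumor pixel
        cy, cx = tumor[rng.integers(len(tumor))]
        top = int(np.clip(cy - ch // 2, 0, h - ch))
        left = int(np.clip(cx - cw // 2, 0, w - cw))
    else:
        # Plain random crop
        top = int(rng.integers(0, h - ch + 1))
        left = int(rng.integers(0, w - cw + 1))
    img = image[top:top + ch, left:left + cw]
    msk = mask[top:top + ch, left:left + cw]

    # Resize back to (h, w) by nearest-neighbour indexing so the mask stays binary
    rows = np.arange(h) * ch // h
    cols = np.arange(w) * cw // w
    img = img[rows[:, None], cols[None, :]]
    msk = msk[rows[:, None], cols[None, :]]

    if rng.random() < 0.5:  # horizontal flip
        img, msk = img[:, ::-1], msk[:, ::-1]

    # Brightness and contrast jitter of +/-15% (image only)
    img = img * rng.uniform(0.85, 1.15)
    mean = img.mean()
    img = (img - mean) * rng.uniform(0.85, 1.15) + mean
    return np.clip(img, 0.0, 1.0), msk
```

Usage: `img_aug, mask_aug = augment_pair(image, mask, np.random.default_rng(0))`. A seeded generator keeps the augmentation reproducible.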
---
## Evaluation
Run the core test-set evaluation (pixel-level Dice, IoU, AUC-ROC and AUC-PR):
```python
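import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score

model = tf.keras.models.load_model('./model/eff_net.h5', compile=False)

# `test_ds` must be defined beforehand: a tf.data.Dataset yielding (images, masks)
# batches with 0/1 masks of shape (B, 128, 128, 1), e.g. parsed from tfrecord/test_*.tfrecord
y_true_list, y_pred_list, y_prob_list = [], [], []
for images, masks in test_ds:
    preds = model.predict(images, verbose=0)              # (B, 128, 128, 2) softmax
    preds_bin = np.argmax(preds, axis=-1, keepdims=True)  # (B, 128, 128, 1) hard labels
    y_true_list.append(masks.numpy())
    y_pred_list.append(preds_bin)
    y_prob_list.append(preds[..., 1:])                    # tumor-class probability

y_true = np.concatenate(y_true_list).flatten().astype(int)
y_pred = np.concatenate(y_pred_list).flatten()
y_prob = np.concatenate(y_prob_list).flatten()

# Pixel-level confusion counts; labels=[0, 1] keeps a 2x2 matrix even if a class is absent
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
dice = (2 * tp) / (2 * tp + fp + fn + 1e-6)
iou = tp / (tp + fp + fn + 1e-6)
print(f'Dice    : {dice:.4f}')
print(f'IoU     : {iou:.4f}')
print(f'AUC-ROC : {roc_auc_score(y_true, y_prob):.4f}')
print(f'AUC-PR  : {average_precision_score(y_true, y_prob):.4f}')  # average precision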
```
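The remaining pixel-level metrics can be computed from the same confusion counts. Below is a minimal sketch, assuming `tp`, `fp`, `fn`, `tn` come from the loop above. The ± values in the Results tables suggest per-image statistics, so a per-image Dice helper is included as well. How the reported values were actually aggregated is not stated here. Scoring a pair of empty masks as 1.0 is an assumed convention. Both helper names are hypothetical:

```python
import numpy as np

def confusion_metrics(tp, fp, fn, tn, eps=1e-6):
    """Pixel-level metrics from confusion counts."""
    return {
        'accuracy':    (tp + tn) / (tp + tn + fp + fn + eps),
        'precision':   tp / (tp + fp + eps),
        'recall':      tp / (tp + fn + eps),  # sensitivity
        'specificity': tn / (tn + fp + eps),
        'dice':        2 * tp / (2 * tp + fp + fn + eps),
        'iou':         tp / (tp + fp + fn + eps),
        # Volumetric similarity: 1 - |FN - FP| / (2TP + FP + FN)
        'vs':          1 - abs(fn - fp) / (2 * tp + fp + fn + eps),
    }

def per_image_dice(y_true, y_pred):
    """Dice per image for (N, H, W) binary arrays; returns (mean, std)."""
    t = y_true.reshape(len(y_true), -1).astype(bool)
    p = y_pred.reshape(len(y_pred), -1).astype(bool)
    inter = np.sum(t & p, axis=1)
    denom = t.sum(axis=1) + p.sum(axis=1)
    # Both masks empty -> score 1.0; otherwise the usual 2|A∩B| / (|A| + |B|)
    dice = np.where(denom == 0, 1.0, 2 * inter / np.maximum(denom, 1))
    return float(dice.mean()), float(dice.std())
```

For example, `confusion_metrics(tp, fp, fn, tn)['vs']` gives volumetric similarity. `per_image_dice` expects `(N, H, W)` arrays, such as the concatenated masks before flattening with the channel axis squeezed.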
---
## Requirements
```
tensorflow>=2.13
numpy
scikit-learn
scipy
matplotlib
opencv-python
Pillow
```
---
## Citation
If you use this code or models in your research, please cite:
```bibtex
@misc{brisc2025-segmentation,
title = {Brain Tumor Segmentation on BRISC2025 using DeepLabV3+ and EfficientNet-UNet},
author = {Your Name},
year = {2025},
url = {https://github.com/your-username/brain-tumor-segmentation}
}
```
**Dataset:**
```bibtex
@dataset{brisc2025,
title = {BRISC2025: Brain Image Segmentation Challenge 2025},
year = {2025},
url = {https://brisc2025.grand-challenge.org}
}
```
---
## License
This project is licensed under the MIT License — see [LICENSE](LICENSE) for details.
---
## Acknowledgements
- [BRISC2025 Challenge](https://brisc2025.grand-challenge.org) for the dataset
- [TensorFlow / Keras](https://tensorflow.org) for the deep learning framework
- DeepLabV3+ architecture: [Chen et al., 2018](https://arxiv.org/abs/1802.02611)
- EfficientNet: [Tan & Le, 2019](https://arxiv.org/abs/1905.11946)