Spaces:
Sleeping
Sleeping
JANGALA SAKETH commited on
Upload 6 files
Browse files- .gitattributes +1 -0
- README.md +89 -0
- app.py +245 -0
- requirements.txt +9 -0
- unet3d_model.pth +3 -0
- unet_model.py +95 -0
- utils.py +123 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
unet3d_model.pth filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Brain Tumor Segmentation
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.48.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
## Sample Data
|
| 14 |
+
|
| 15 |
+
You can download sample NIfTI files for two patients to test the model from this Google Drive link:
|
| 16 |
+
|
| 17 |
+
[Sample Data (Google Drive)](https://drive.google.com/drive/folders/19LzKOcoIrWQhwY91e_kn644AcQi4tl8z?usp=sharing)
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
title: Brain Tumor Segmentation
|
| 21 |
+
emoji: 🧠
|
| 22 |
+
colorFrom: blue
|
| 23 |
+
colorTo: pink
|
| 24 |
+
sdk: streamlit
|
| 25 |
+
sdk_version: 1.48.1
|
| 26 |
+
app_file: app.py
|
| 27 |
+
pinned: false
|
| 28 |
+
license: mit
|
| 29 |
+
---
|
| 30 |
+
|
| 31 |
+
# Brain Tumor Segmentation App
|
| 32 |
+
|
| 33 |
+
<p align="center">
|
| 34 |
+
<img src="https://img.shields.io/badge/Streamlit-Online-brightgreen" alt="Streamlit">
|
| 35 |
+
<a href="https://huggingface.co/spaces/saketh-005/brain-tumor-segmentation"><img src="https://img.shields.io/badge/HuggingFace-Spaces-yellow" alt="Hugging Face Spaces"></a>
|
| 36 |
+
</p>
|
| 37 |
+
|
| 38 |
+
This project is a web application for brain tumor segmentation from 3D/4D NIfTI MRI scans using a 3D U-Net model, built with PyTorch and Streamlit. You can run it locally or deploy it on [Hugging Face Spaces](https://huggingface.co/spaces).
|
| 39 |
+
|
| 40 |
+
## Features
|
| 41 |
+
- Upload four 3D NIfTI brain scans (T1, T1ce, T2, FLAIR)
|
| 42 |
+
- Automatic preprocessing and patch-based inference
|
| 43 |
+
- Visualizes the predicted tumor mask overlayed on the MRI
|
| 44 |
+
|
| 45 |
+
## Quick Start (Hugging Face Spaces)
|
| 46 |
+
|
| 47 |
+
1. **Upload your trained model file** (`unet3d_model.pth`) to the Space's root directory or use a download link in the code.
|
| 48 |
+
2. Click "Run" or "Duplicate Space" to use your own model.
|
| 49 |
+
3. Use the web interface to upload your NIfTI files and view results.
|
| 50 |
+
|
| 51 |
+
## Local Usage
|
| 52 |
+
1. Clone this repository:
|
| 53 |
+
```sh
|
| 54 |
+
git clone https://github.com/saketh-005/brain-tumor-segmentation.git
|
| 55 |
+
cd brain-tumor-segmentation
|
| 56 |
+
```
|
| 57 |
+
2. (Recommended) Create and activate a Python virtual environment:
|
| 58 |
+
```sh
|
| 59 |
+
python3 -m venv .venv
|
| 60 |
+
source .venv/bin/activate
|
| 61 |
+
```
|
| 62 |
+
3. Install dependencies:
|
| 63 |
+
```sh
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
```
|
| 66 |
+
4. Download the trained model file (`unet3d_model.pth`) and place it in this directory. (Due to file size, it is not included in the repo. Please contact the author or use your own trained model.)
|
| 67 |
+
5. Run the app:
|
| 68 |
+
```sh
|
| 69 |
+
streamlit run app.py
|
| 70 |
+
```
|
| 71 |
+
6. Open your browser to [http://localhost:8501](http://localhost:8501) and use the app.
|
| 72 |
+
|
| 73 |
+
## File Structure
|
| 74 |
+
- `app.py` - Main Streamlit app
|
| 75 |
+
- `unet_model.py` - 3D U-Net model definition
|
| 76 |
+
- `utils.py` - Preprocessing, postprocessing, and visualization utilities
|
| 77 |
+
- `requirements.txt` - Python dependencies
|
| 78 |
+
- `unet3d_model.pth` - Trained model weights (**not included**)
|
| 79 |
+
|
| 80 |
+
## Notes
|
| 81 |
+
- The model file (`unet3d_model.pth`) must be trained and exported separately.
|
| 82 |
+
- For large files, use cloud storage and provide a download link in this README or in your Hugging Face Space.
|
| 83 |
+
- For best results, ensure all input NIfTI files have the same dimensions and orientation.
|
| 84 |
+
|
| 85 |
+
## License
|
| 86 |
+
MIT License
|
| 87 |
+
|
| 88 |
+
## Author
|
| 89 |
+
[Saketh Jangala](https://github.com/saketh-005)
|
app.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Patch-based Inference Helper ---
|
| 2 |
+
def run_patch_inference(model, tensor, patch_depth=32):
|
| 3 |
+
"""
|
| 4 |
+
Run model inference on 3D tensor in patches along the depth axis.
|
| 5 |
+
Args:
|
| 6 |
+
model: The 3D segmentation model.
|
| 7 |
+
tensor: Input tensor of shape [1, 4, D, H, W].
|
| 8 |
+
patch_depth: Depth of each patch.
|
| 9 |
+
Returns:
|
| 10 |
+
Output tensor stitched together.
|
| 11 |
+
"""
|
| 12 |
+
device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')
|
| 13 |
+
_, c, d, h, w = tensor.shape
|
| 14 |
+
output = []
|
| 15 |
+
for start in range(0, d, patch_depth):
|
| 16 |
+
end = min(start + patch_depth, d)
|
| 17 |
+
patch = tensor[:, :, start:end, :, :]
|
| 18 |
+
with torch.no_grad():
|
| 19 |
+
patch_out = model(patch.to(device))
|
| 20 |
+
output.append(patch_out.cpu())
|
| 21 |
+
# Concatenate along the depth axis
|
| 22 |
+
return torch.cat(output, dim=2)
|
| 23 |
+
import streamlit as st
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import nibabel as nib
|
| 28 |
+
import numpy as np
|
| 29 |
+
import os
|
| 30 |
+
import io
|
| 31 |
+
import tempfile
|
| 32 |
+
|
| 33 |
+
from utils import preprocess_nifti, postprocess_mask, visualize_prediction, combine_nifti_files
|
| 34 |
+
|
| 35 |
+
# --- Page Configuration ---
|
| 36 |
+
st.set_page_config(
|
| 37 |
+
page_title="Brain Tumor Segmentation App",
|
| 38 |
+
layout="wide",
|
| 39 |
+
initial_sidebar_state="expanded"
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# --- App Title and Description ---
|
| 43 |
+
st.title("Brain Tumor Segmentation")
|
| 44 |
+
st.write("Upload the four 3D NIfTI brain scans (.nii or .nii.gz) for each modality to get a segmentation mask of the tumor.")
|
| 45 |
+
st.markdown("---")
|
| 46 |
+
|
| 47 |
+
# --- Model Architecture ---
|
| 48 |
+
# A single block in the U-Net architecture.
|
| 49 |
+
class DoubleConv(nn.Module):
|
| 50 |
+
"""(convolution => GroupNorm => ReLU) * 2"""
|
| 51 |
+
def __init__(self, in_channels, out_channels):
|
| 52 |
+
super().__init__()
|
| 53 |
+
# 3D convolutional layers, GroupNorm for stable training, and ReLU activation.
|
| 54 |
+
self.double_conv = nn.Sequential(
|
| 55 |
+
nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
|
| 56 |
+
nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
|
| 57 |
+
nn.ReLU(inplace=True),
|
| 58 |
+
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
|
| 59 |
+
nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
|
| 60 |
+
nn.ReLU(inplace=True)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def forward(self, x):
|
| 64 |
+
return self.double_conv(x)
|
| 65 |
+
|
| 66 |
+
# The downsampling part of the U-Net.
|
| 67 |
+
class Down(nn.Module):
|
| 68 |
+
"""Downscaling with maxpool then double conv"""
|
| 69 |
+
def __init__(self, in_channels, out_channels):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.encoder = nn.Sequential(
|
| 72 |
+
nn.MaxPool3d(2),
|
| 73 |
+
DoubleConv(in_channels, out_channels)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def forward(self, x):
|
| 77 |
+
return self.encoder(x)
|
| 78 |
+
|
| 79 |
+
# The upsampling part of the U-Net.
|
| 80 |
+
class Up(nn.Module):
|
| 81 |
+
"""Upscaling then double conv"""
|
| 82 |
+
def __init__(self, in_channels, out_channels):
|
| 83 |
+
super().__init__()
|
| 84 |
+
# Use bilinear upsampling and then a convolution layer
|
| 85 |
+
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 86 |
+
self.conv = DoubleConv(in_channels, out_channels)
|
| 87 |
+
|
| 88 |
+
def forward(self, x1, x2):
|
| 89 |
+
x1 = self.up(x1)
|
| 90 |
+
# Pad x1 to match the size of x2 for concatenation
|
| 91 |
+
diffX = x2.size()[2] - x1.size()[2]
|
| 92 |
+
diffY = x2.size()[3] - x1.size()[3]
|
| 93 |
+
diffZ = x2.size()[4] - x1.size()[4]
|
| 94 |
+
|
| 95 |
+
x1 = F.pad(x1, [diffZ // 2, diffZ - diffZ // 2,
|
| 96 |
+
diffY // 2, diffY - diffY // 2,
|
| 97 |
+
diffX // 2, diffX - diffX // 2])
|
| 98 |
+
|
| 99 |
+
x = torch.cat([x2, x1], dim=1)
|
| 100 |
+
return self.conv(x)
|
| 101 |
+
|
| 102 |
+
# The final output convolutional layer.
|
| 103 |
+
class Out(nn.Module):
|
| 104 |
+
def __init__(self, in_channels, out_channels):
|
| 105 |
+
super(Out, self).__init__()
|
| 106 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
return self.conv(x)
|
| 110 |
+
|
| 111 |
+
# The complete 3D U-Net model.
|
| 112 |
+
class UNet3d(nn.Module):
|
| 113 |
+
def __init__(self, n_channels=4, n_classes=3):
|
| 114 |
+
super().__init__()
|
| 115 |
+
# The number of classes is 3 (tumor core, edema, enhancing tumor).
|
| 116 |
+
self.n_channels = n_channels
|
| 117 |
+
self.n_classes = n_classes
|
| 118 |
+
|
| 119 |
+
# Contracting path
|
| 120 |
+
self.conv = DoubleConv(n_channels, 16)
|
| 121 |
+
self.enc1 = Down(16, 32)
|
| 122 |
+
self.enc2 = Down(32, 64)
|
| 123 |
+
self.enc3 = Down(64, 128)
|
| 124 |
+
self.enc4 = Down(128, 256)
|
| 125 |
+
|
| 126 |
+
# Expansive path
|
| 127 |
+
self.dec1 = Up(256 + 128, 128)
|
| 128 |
+
self.dec2 = Up(128 + 64, 64)
|
| 129 |
+
self.dec3 = Up(64 + 32, 32)
|
| 130 |
+
self.dec4 = Up(32 + 16, 16)
|
| 131 |
+
|
| 132 |
+
self.out = Out(16, n_classes)
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
x1 = self.conv(x)
|
| 136 |
+
x2 = self.enc1(x1)
|
| 137 |
+
x3 = self.enc2(x2)
|
| 138 |
+
x4 = self.enc3(x3)
|
| 139 |
+
x5 = self.enc4(x4)
|
| 140 |
+
|
| 141 |
+
x = self.dec1(x5, x4)
|
| 142 |
+
x = self.dec2(x, x3)
|
| 143 |
+
x = self.dec3(x, x2)
|
| 144 |
+
x = self.dec4(x, x1)
|
| 145 |
+
|
| 146 |
+
logits = self.out(x)
|
| 147 |
+
return logits
|
| 148 |
+
|
| 149 |
+
# --- Model Loading ---
|
| 150 |
+
@st.cache_resource
|
| 151 |
+
def load_model(model_path):
|
| 152 |
+
"""Loads the trained PyTorch model from a .pth file."""
|
| 153 |
+
try:
|
| 154 |
+
# FIX: Directly load the model object, which is what was saved.
|
| 155 |
+
# The weights_only=False argument is needed for custom classes.
|
| 156 |
+
model = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
|
| 157 |
+
|
| 158 |
+
model.eval()
|
| 159 |
+
st.success("Model loaded successfully!")
|
| 160 |
+
return model
|
| 161 |
+
except Exception as e:
|
| 162 |
+
st.error(f"Error loading model: {e}")
|
| 163 |
+
return None
|
| 164 |
+
|
| 165 |
+
# --- Main App Logic ---
|
| 166 |
+
model_file_path = "unet3d_model.pth"
|
| 167 |
+
if not os.path.exists(model_file_path):
|
| 168 |
+
st.warning("Model file 'unet3d_model.pth' not found. Please ensure it is in the same directory.")
|
| 169 |
+
model = None
|
| 170 |
+
else:
|
| 171 |
+
model = load_model(model_file_path)
|
| 172 |
+
|
| 173 |
+
st.sidebar.header("Upload NIfTI Files")
|
| 174 |
+
t1_file = st.sidebar.file_uploader("Choose a T1 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1")
|
| 175 |
+
t1ce_file = st.sidebar.file_uploader("Choose a T1ce scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1ce")
|
| 176 |
+
t2_file = st.sidebar.file_uploader("Choose a T2 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t2")
|
| 177 |
+
flair_file = st.sidebar.file_uploader("Choose a FLAIR scan (.nii or .nii.gz)", type=["nii", "gz"], key="flair")
|
| 178 |
+
|
| 179 |
+
if t1_file and t1ce_file and t2_file and flair_file and model is not None:
|
| 180 |
+
st.info("All files uploaded successfully. Processing...")
|
| 181 |
+
|
| 182 |
+
# temp_combined_file_path is now defined at the start of the block
|
| 183 |
+
temp_combined_file_path = None
|
| 184 |
+
|
| 185 |
+
with st.spinner("Combining NIfTI files and making prediction..."):
|
| 186 |
+
try:
|
| 187 |
+
# Create temporary files for each uploaded file
|
| 188 |
+
with tempfile.NamedTemporaryFile(suffix=f"_{t1_file.name}") as t1_temp, \
|
| 189 |
+
tempfile.NamedTemporaryFile(suffix=f"_{t1ce_file.name}") as t1ce_temp, \
|
| 190 |
+
tempfile.NamedTemporaryFile(suffix=f"_{t2_file.name}") as t2_temp, \
|
| 191 |
+
tempfile.NamedTemporaryFile(suffix=f"_{flair_file.name}") as flair_temp:
|
| 192 |
+
|
| 193 |
+
t1_temp.write(t1_file.getvalue())
|
| 194 |
+
t1ce_temp.write(t1ce_file.getvalue())
|
| 195 |
+
t2_temp.write(t2_file.getvalue())
|
| 196 |
+
flair_temp.write(flair_file.getvalue())
|
| 197 |
+
|
| 198 |
+
# Pass the temporary file paths to the combine function
|
| 199 |
+
combined_nifti_img = combine_nifti_files(t1_temp.name, t1ce_temp.name, t2_temp.name, flair_temp.name)
|
| 200 |
+
|
| 201 |
+
original_data = combined_nifti_img.get_fdata()
|
| 202 |
+
|
| 203 |
+
# Preprocess the combined image
|
| 204 |
+
# We need to save the combined NIfTI object to a file for nibabel to load it properly
|
| 205 |
+
temp_combined_file_path = "combined_4d.nii.gz"
|
| 206 |
+
nib.save(combined_nifti_img, temp_combined_file_path)
|
| 207 |
+
|
| 208 |
+
_, processed_tensor = preprocess_nifti(temp_combined_file_path)
|
| 209 |
+
|
| 210 |
+
if original_data is not None and processed_tensor is not None:
|
| 211 |
+
st.success("Preprocessing complete!")
|
| 212 |
+
|
| 213 |
+
# --- Patch-based Model Prediction ---
|
| 214 |
+
st.info("Running patch-based model inference...")
|
| 215 |
+
try:
|
| 216 |
+
prediction_tensor = run_patch_inference(model, processed_tensor, patch_depth=32)
|
| 217 |
+
st.success("Prediction complete!")
|
| 218 |
+
except Exception as e:
|
| 219 |
+
st.error(f"Error during patch-based inference: {e}")
|
| 220 |
+
raise
|
| 221 |
+
|
| 222 |
+
# Post-process the prediction to get a mask, resizing back to original size
|
| 223 |
+
predicted_mask = postprocess_mask(prediction_tensor, original_data.shape)
|
| 224 |
+
|
| 225 |
+
if predicted_mask is not None:
|
| 226 |
+
st.header("Results")
|
| 227 |
+
# Ensure mask is int and shape matches for visualization
|
| 228 |
+
max_slices = original_data.shape[2]
|
| 229 |
+
slice_index = st.slider("Select an axial slice to view", 0, max_slices - 1, max_slices // 2)
|
| 230 |
+
fig = visualize_prediction(original_data, predicted_mask.astype(int), slice_index=slice_index)
|
| 231 |
+
st.pyplot(fig)
|
| 232 |
+
else:
|
| 233 |
+
st.error("Could not post-process the model's prediction.")
|
| 234 |
+
|
| 235 |
+
except Exception as e:
|
| 236 |
+
st.error(f"An error occurred during processing: {e}")
|
| 237 |
+
st.error("Please ensure the uploaded files are valid NIfTI files with the same dimensions.")
|
| 238 |
+
finally:
|
| 239 |
+
# Clean up temporary files
|
| 240 |
+
if os.path.exists(temp_combined_file_path):
|
| 241 |
+
os.remove(temp_combined_file_path)
|
| 242 |
+
|
| 243 |
+
# --- Footer ---
|
| 244 |
+
st.markdown("---")
|
| 245 |
+
st.markdown("Developed with PyTorch and Streamlit.")
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
numpy
|
| 5 |
+
matplotlib
|
| 6 |
+
pandas
|
| 7 |
+
nibabel
|
| 8 |
+
Pillow
|
| 9 |
+
scikit-image
|
unet3d_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:24b57e013f9f929e2a2c39af499ec86cf6dae2d42eb4481a0ccd4e2b94a984f7
|
| 3 |
+
size 22655363
|
unet_model.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class DoubleConv(nn.Module):
|
| 6 |
+
"""(convolution => GroupNorm => ReLU) * 2"""
|
| 7 |
+
def __init__(self, in_channels, out_channels):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.double_conv = nn.Sequential(
|
| 10 |
+
nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
|
| 11 |
+
nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
|
| 12 |
+
nn.ReLU(inplace=True),
|
| 13 |
+
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
|
| 14 |
+
nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
|
| 15 |
+
nn.ReLU(inplace=True)
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return self.double_conv(x)
|
| 20 |
+
|
| 21 |
+
class Down(nn.Module):
|
| 22 |
+
"""Downscaling with maxpool then double conv"""
|
| 23 |
+
def __init__(self, in_channels, out_channels):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.encoder = nn.Sequential(
|
| 26 |
+
nn.MaxPool3d(2),
|
| 27 |
+
DoubleConv(in_channels, out_channels)
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
def forward(self, x):
|
| 31 |
+
return self.encoder(x)
|
| 32 |
+
|
| 33 |
+
class Up(nn.Module):
|
| 34 |
+
"""Upscaling then double conv"""
|
| 35 |
+
def __init__(self, in_channels, out_channels):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
|
| 38 |
+
self.conv = DoubleConv(in_channels, out_channels)
|
| 39 |
+
|
| 40 |
+
def forward(self, x1, x2):
|
| 41 |
+
x1 = self.up(x1)
|
| 42 |
+
diffX = x2.size()[2] - x1.size()[2]
|
| 43 |
+
diffY = x2.size()[3] - x1.size()[3]
|
| 44 |
+
diffZ = x2.size()[4] - x1.size()[4]
|
| 45 |
+
|
| 46 |
+
x1 = F.pad(x1, [diffZ // 2, diffZ - diffZ // 2,
|
| 47 |
+
diffY // 2, diffY - diffY // 2,
|
| 48 |
+
diffX // 2, diffX - diffX // 2])
|
| 49 |
+
|
| 50 |
+
x = torch.cat([x2, x1], dim=1)
|
| 51 |
+
return self.conv(x)
|
| 52 |
+
|
| 53 |
+
class OutConv(nn.Module):
|
| 54 |
+
def __init__(self, in_channels, out_channels):
|
| 55 |
+
super(OutConv, self).__init__()
|
| 56 |
+
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
return self.conv(x)
|
| 60 |
+
|
| 61 |
+
class UNet3d(nn.Module):
|
| 62 |
+
def __init__(self, n_channels=4, n_classes=3):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.n_channels = n_channels
|
| 65 |
+
self.n_classes = n_classes
|
| 66 |
+
|
| 67 |
+
# Contracting path
|
| 68 |
+
self.conv = DoubleConv(n_channels, 16)
|
| 69 |
+
self.enc1 = Down(16, 32)
|
| 70 |
+
self.enc2 = Down(32, 64)
|
| 71 |
+
self.enc3 = Down(64, 128)
|
| 72 |
+
self.enc4 = Down(128, 256)
|
| 73 |
+
|
| 74 |
+
# Expansive path
|
| 75 |
+
self.dec1 = Up(256 + 128, 128)
|
| 76 |
+
self.dec2 = Up(128 + 64, 64)
|
| 77 |
+
self.dec3 = Up(64 + 32, 32)
|
| 78 |
+
self.dec4 = Up(32 + 16, 16)
|
| 79 |
+
|
| 80 |
+
self.out = OutConv(16, n_classes)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
x1 = self.conv(x)
|
| 84 |
+
x2 = self.enc1(x1)
|
| 85 |
+
x3 = self.enc2(x2)
|
| 86 |
+
x4 = self.enc3(x3)
|
| 87 |
+
x5 = self.enc4(x4)
|
| 88 |
+
|
| 89 |
+
x = self.dec1(x5, x4)
|
| 90 |
+
x = self.dec2(x, x3)
|
| 91 |
+
x = self.dec3(x, x2)
|
| 92 |
+
x = self.dec4(x, x1)
|
| 93 |
+
|
| 94 |
+
logits = self.out(x)
|
| 95 |
+
return logits
|
utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import nibabel as nib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import matplotlib.patches as mpatches
|
| 7 |
+
import streamlit as st
|
| 8 |
+
import io
|
| 9 |
+
import tempfile
|
| 10 |
+
from skimage.transform import resize
|
| 11 |
+
|
| 12 |
+
def preprocess_nifti(nifti_file):
|
| 13 |
+
"""
|
| 14 |
+
Loads a NIfTI file, preprocesses it, and returns a PyTorch tensor.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
nifti_file (str or io.BytesIO): Path to the NIfTI file or a file-like object.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
tuple: A tuple containing the original image data and a preprocessed tensor.
|
| 21 |
+
"""
|
| 22 |
+
try:
|
| 23 |
+
nifti_img = nib.load(nifti_file)
|
| 24 |
+
img_data = nifti_img.get_fdata()
|
| 25 |
+
|
| 26 |
+
if len(img_data.shape) != 4 or img_data.shape[-1] != 4:
|
| 27 |
+
st.error("The uploaded NIfTI file must be a 4D image with 4 channels.")
|
| 28 |
+
return None, None
|
| 29 |
+
|
| 30 |
+
for i in range(img_data.shape[-1]):
|
| 31 |
+
channel_data = img_data[..., i]
|
| 32 |
+
if np.max(channel_data) > 0:
|
| 33 |
+
img_data[..., i] = channel_data / np.max(channel_data)
|
| 34 |
+
|
| 35 |
+
img_data = np.transpose(img_data, (3, 2, 0, 1))
|
| 36 |
+
|
| 37 |
+
tensor_data = torch.from_numpy(img_data).float()
|
| 38 |
+
tensor_data = torch.unsqueeze(tensor_data, 0)
|
| 39 |
+
|
| 40 |
+
return nifti_img.get_fdata(), tensor_data
|
| 41 |
+
|
| 42 |
+
except Exception as e:
|
| 43 |
+
st.error(f"Error during NIfTI preprocessing: {e}")
|
| 44 |
+
return None, None
|
| 45 |
+
|
| 46 |
+
def postprocess_mask(prediction_tensor, original_shape):
|
| 47 |
+
"""
|
| 48 |
+
Converts model output (tensor) into a visualizable mask (numpy array)
|
| 49 |
+
and resizes it to the original image dimensions.
|
| 50 |
+
"""
|
| 51 |
+
try:
|
| 52 |
+
probabilities = F.softmax(prediction_tensor, dim=1)
|
| 53 |
+
|
| 54 |
+
mask = torch.argmax(probabilities, dim=1)
|
| 55 |
+
|
| 56 |
+
mask = mask.detach().cpu().numpy()
|
| 57 |
+
mask = np.squeeze(mask)
|
| 58 |
+
|
| 59 |
+
mask = np.transpose(mask, (1, 2, 0))
|
| 60 |
+
|
| 61 |
+
resized_mask = resize(mask, original_shape[:3], order=0, preserve_range=True, anti_aliasing=False)
|
| 62 |
+
|
| 63 |
+
return resized_mask
|
| 64 |
+
except Exception as e:
|
| 65 |
+
st.error(f"Error during mask post-processing: {e}")
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
def visualize_prediction(original_image, predicted_mask, slice_index=75):
|
| 69 |
+
"""
|
| 70 |
+
Creates a 2-panel visualization of the original image and the predicted mask.
|
| 71 |
+
"""
|
| 72 |
+
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
|
| 73 |
+
# Show FLAIR channel for the original image
|
| 74 |
+
axes[0].imshow(np.rot90(original_image[:, :, slice_index, 3]), cmap='bone')
|
| 75 |
+
axes[0].set_title('Original Image (FLAIR)', fontsize=16)
|
| 76 |
+
axes[0].axis('off')
|
| 77 |
+
|
| 78 |
+
axes[1].imshow(np.rot90(original_image[:, :, slice_index, 3]), cmap='bone')
|
| 79 |
+
# Overlay the predicted mask directly
|
| 80 |
+
mask_slice = np.rot90(predicted_mask[:, :, slice_index])
|
| 81 |
+
axes[1].imshow(np.ma.masked_where(mask_slice == 0, mask_slice), cmap='jet', alpha=0.5)
|
| 82 |
+
axes[1].set_title('Predicted Tumor Mask', fontsize=16)
|
| 83 |
+
axes[1].axis('off')
|
| 84 |
+
return fig
|
| 85 |
+
|
| 86 |
+
def combine_nifti_files(t1_file_path, t1ce_file_path, t2_file_path, flair_file_path):
|
| 87 |
+
"""
|
| 88 |
+
Combines four 3D NIfTI files from given paths into a single 4D NIfTI file object.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
t1_file_path, t1ce_file_path, t2_file_path, flair_file_path (str): Paths to the temporary NIfTI files.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
nib.Nifti1Image: A 4D NIfTI image object.
|
| 95 |
+
"""
|
| 96 |
+
try:
|
| 97 |
+
# Load the four 3D NIfTI files from file paths
|
| 98 |
+
t1_img = nib.load(t1_file_path)
|
| 99 |
+
t1ce_img = nib.load(t1ce_file_path)
|
| 100 |
+
t2_img = nib.load(t2_file_path)
|
| 101 |
+
flair_img = nib.load(flair_file_path)
|
| 102 |
+
|
| 103 |
+
# Get the image data as NumPy arrays
|
| 104 |
+
t1_data = t1_img.get_fdata()
|
| 105 |
+
t1ce_data = t1ce_img.get_fdata()
|
| 106 |
+
t2_data = t2_img.get_fdata()
|
| 107 |
+
flair_data = flair_img.get_fdata()
|
| 108 |
+
|
| 109 |
+
# Ensure all files have the same shape
|
| 110 |
+
if not (t1_data.shape == t1ce_data.shape == t2_data.shape == flair_data.shape):
|
| 111 |
+
st.error("Error: Input NIfTI files do not have matching dimensions.")
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
# Stack the 3D arrays along a new (4th) dimension to create a 4D array
|
| 115 |
+
combined_data = np.stack([t1_data, t1ce_data, t2_data, flair_data], axis=-1)
|
| 116 |
+
|
| 117 |
+
# Create a new 4D NIfTI image object
|
| 118 |
+
combined_img = nib.Nifti1Image(combined_data, t1_img.affine)
|
| 119 |
+
|
| 120 |
+
return combined_img
|
| 121 |
+
except Exception as e:
|
| 122 |
+
st.error(f"Error combining NIfTI files: {e}")
|
| 123 |
+
return None
|