JANGALA SAKETH committed
Commit fa7fb3e · verified · 1 parent: c5562a1

Upload 6 files

Files changed (7)
  1. .gitattributes +1 -0
  2. README.md +89 -0
  3. app.py +245 -0
  4. requirements.txt +9 -0
  5. unet3d_model.pth +3 -0
  6. unet_model.py +95 -0
  7. utils.py +123 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ unet3d_model.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,89 @@
+ ---
+ title: Brain Tumor Segmentation
+ emoji: 🧠
+ colorFrom: blue
+ colorTo: pink
+ sdk: streamlit
+ sdk_version: 1.48.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ # Brain Tumor Segmentation App
+
+ <p align="center">
+   <img src="https://img.shields.io/badge/Streamlit-Online-brightgreen" alt="Streamlit">
+   <a href="https://huggingface.co/spaces/saketh-005/brain-tumor-segmentation"><img src="https://img.shields.io/badge/HuggingFace-Spaces-yellow" alt="Hugging Face Spaces"></a>
+ </p>
+
+ This project is a web application for brain tumor segmentation from 3D/4D NIfTI MRI scans using a 3D U-Net model, built with PyTorch and Streamlit. You can run it locally or deploy it on [Hugging Face Spaces](https://huggingface.co/spaces).
+
+ ## Features
+ - Upload four 3D NIfTI brain scans (T1, T1ce, T2, FLAIR)
+ - Automatic preprocessing and patch-based inference
+ - Visualizes the predicted tumor mask overlaid on the MRI
+
+ ## Sample Data
+
+ You can download sample NIfTI files for two patients to test the model from this Google Drive link:
+
+ [Sample Data (Google Drive)](https://drive.google.com/drive/folders/19LzKOcoIrWQhwY91e_kn644AcQi4tl8z?usp=sharing)
+
+ ## Quick Start (Hugging Face Spaces)
+
+ 1. **Upload your trained model file** (`unet3d_model.pth`) to the Space's root directory, or add a download link in the code.
+ 2. Click "Run", or "Duplicate Space" to use your own model.
+ 3. Use the web interface to upload your NIfTI files and view the results.
+
+ ## Local Usage
+ 1. Clone this repository:
+    ```sh
+    git clone https://github.com/saketh-005/brain-tumor-segmentation.git
+    cd brain-tumor-segmentation
+    ```
+ 2. (Recommended) Create and activate a Python virtual environment:
+    ```sh
+    python3 -m venv .venv
+    source .venv/bin/activate
+    ```
+ 3. Install dependencies:
+    ```sh
+    pip install -r requirements.txt
+    ```
+ 4. Download the trained model file (`unet3d_model.pth`) and place it in this directory. (Due to its size, it is not included in the repo; contact the author or use your own trained model.)
+ 5. Run the app:
+    ```sh
+    streamlit run app.py
+    ```
+ 6. Open your browser to [http://localhost:8501](http://localhost:8501) and use the app.
+
+ ## File Structure
+ - `app.py` - Main Streamlit app
+ - `unet_model.py` - 3D U-Net model definition
+ - `utils.py` - Preprocessing, postprocessing, and visualization utilities
+ - `requirements.txt` - Python dependencies
+ - `unet3d_model.pth` - Trained model weights (**not included**)
+
+ ## Notes
+ - The model file (`unet3d_model.pth`) must be trained and exported separately.
+ - For large files, use cloud storage and provide a download link in this README or in your Hugging Face Space.
+ - For best results, ensure all input NIfTI files have the same dimensions and orientation.
+
+ ## License
+ MIT License
+
+ ## Author
+ [Saketh Jangala](https://github.com/saketh-005)
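The "same dimensions" note above can be checked before uploading; a minimal sketch with NumPy arrays standing in for volumes loaded via `nib.load(path).get_fdata()` (the helper name `shapes_match` is illustrative, not part of this repo):

```python
import numpy as np

def shapes_match(*volumes):
    """Return True when every volume has the same 3D shape."""
    return len({v.shape for v in volumes}) == 1

# Small stand-ins for the four modality volumes (T1, T1ce, T2, FLAIR)
t1, t1ce, t2, flair = (np.zeros((4, 4, 5)) for _ in range(4))
print(shapes_match(t1, t1ce, t2, flair))  # prints True
```

The app reports an error from `combine_nifti_files` when this check would fail, so running it locally first saves an upload round-trip.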
app.py ADDED
@@ -0,0 +1,245 @@
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import nibabel as nib
+ import numpy as np
+ import os
+ import tempfile
+
+ from utils import preprocess_nifti, postprocess_mask, visualize_prediction, combine_nifti_files
+
+ # --- Patch-based Inference Helper ---
+ def run_patch_inference(model, tensor, patch_depth=32):
+     """
+     Run model inference on a 3D tensor in patches along the depth axis.
+     Args:
+         model: The 3D segmentation model.
+         tensor: Input tensor of shape [1, 4, D, H, W].
+         patch_depth: Depth of each patch.
+     Returns:
+         Output tensor stitched together along the depth axis.
+     """
+     device = next(model.parameters()).device
+     _, _, d, _, _ = tensor.shape
+     output = []
+     for start in range(0, d, patch_depth):
+         end = min(start + patch_depth, d)
+         patch = tensor[:, :, start:end, :, :]
+         with torch.no_grad():
+             patch_out = model(patch.to(device))
+         output.append(patch_out.cpu())
+     # Concatenate along the depth axis
+     return torch.cat(output, dim=2)
+
+ # --- Page Configuration ---
+ st.set_page_config(
+     page_title="Brain Tumor Segmentation App",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ # --- App Title and Description ---
+ st.title("Brain Tumor Segmentation")
+ st.write("Upload the four 3D NIfTI brain scans (.nii or .nii.gz) for each modality to get a segmentation mask of the tumor.")
+ st.markdown("---")
+
+ # --- Model Architecture ---
+ # A single block in the U-Net architecture.
+ class DoubleConv(nn.Module):
+     """(convolution => GroupNorm => ReLU) * 2"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         # 3D convolutional layers, GroupNorm for stable training, and ReLU activation.
+         self.double_conv = nn.Sequential(
+             nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.double_conv(x)
+
+ # The downsampling part of the U-Net.
+ class Down(nn.Module):
+     """Downscaling with maxpool then double conv"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.encoder = nn.Sequential(
+             nn.MaxPool3d(2),
+             DoubleConv(in_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.encoder(x)
+
+ # The upsampling part of the U-Net.
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         # Use trilinear upsampling followed by a convolution block
+         self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
+         self.conv = DoubleConv(in_channels, out_channels)
+
+     def forward(self, x1, x2):
+         x1 = self.up(x1)
+         # Pad x1 to match the size of x2 for concatenation
+         diffX = x2.size()[2] - x1.size()[2]
+         diffY = x2.size()[3] - x1.size()[3]
+         diffZ = x2.size()[4] - x1.size()[4]
+
+         x1 = F.pad(x1, [diffZ // 2, diffZ - diffZ // 2,
+                         diffY // 2, diffY - diffY // 2,
+                         diffX // 2, diffX - diffX // 2])
+
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+ # The final output convolutional layer.
+ class Out(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Out, self).__init__()
+         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         return self.conv(x)
+
+ # The complete 3D U-Net model.
+ class UNet3d(nn.Module):
+     def __init__(self, n_channels=4, n_classes=3):
+         super().__init__()
+         # The number of classes is 3 (tumor core, edema, enhancing tumor).
+         self.n_channels = n_channels
+         self.n_classes = n_classes
+
+         # Contracting path
+         self.conv = DoubleConv(n_channels, 16)
+         self.enc1 = Down(16, 32)
+         self.enc2 = Down(32, 64)
+         self.enc3 = Down(64, 128)
+         self.enc4 = Down(128, 256)
+
+         # Expansive path
+         self.dec1 = Up(256 + 128, 128)
+         self.dec2 = Up(128 + 64, 64)
+         self.dec3 = Up(64 + 32, 32)
+         self.dec4 = Up(32 + 16, 16)
+
+         self.out = Out(16, n_classes)
+
+     def forward(self, x):
+         x1 = self.conv(x)
+         x2 = self.enc1(x1)
+         x3 = self.enc2(x2)
+         x4 = self.enc3(x3)
+         x5 = self.enc4(x4)
+
+         x = self.dec1(x5, x4)
+         x = self.dec2(x, x3)
+         x = self.dec3(x, x2)
+         x = self.dec4(x, x1)
+
+         logits = self.out(x)
+         return logits
+
+ # --- Model Loading ---
+ @st.cache_resource
+ def load_model(model_path):
+     """Loads the trained PyTorch model from a .pth file."""
+     try:
+         # Load the full model object, which is what was saved.
+         # weights_only=False is required to unpickle custom classes.
+         model = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
+         model.eval()
+         st.success("Model loaded successfully!")
+         return model
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         return None
+
+ # --- Main App Logic ---
+ model_file_path = "unet3d_model.pth"
+ if not os.path.exists(model_file_path):
+     st.warning("Model file 'unet3d_model.pth' not found. Please ensure it is in the same directory.")
+     model = None
+ else:
+     model = load_model(model_file_path)
+
+ st.sidebar.header("Upload NIfTI Files")
+ t1_file = st.sidebar.file_uploader("Choose a T1 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1")
+ t1ce_file = st.sidebar.file_uploader("Choose a T1ce scan (.nii or .nii.gz)", type=["nii", "gz"], key="t1ce")
+ t2_file = st.sidebar.file_uploader("Choose a T2 scan (.nii or .nii.gz)", type=["nii", "gz"], key="t2")
+ flair_file = st.sidebar.file_uploader("Choose a FLAIR scan (.nii or .nii.gz)", type=["nii", "gz"], key="flair")
+
+ if t1_file and t1ce_file and t2_file and flair_file and model is not None:
+     st.info("All files uploaded successfully. Processing...")
+
+     # Defined up front so the finally block can clean it up safely
+     temp_combined_file_path = None
+
+     with st.spinner("Combining NIfTI files and making prediction..."):
+         try:
+             # Create temporary files for each uploaded file
+             with tempfile.NamedTemporaryFile(suffix=f"_{t1_file.name}") as t1_temp, \
+                  tempfile.NamedTemporaryFile(suffix=f"_{t1ce_file.name}") as t1ce_temp, \
+                  tempfile.NamedTemporaryFile(suffix=f"_{t2_file.name}") as t2_temp, \
+                  tempfile.NamedTemporaryFile(suffix=f"_{flair_file.name}") as flair_temp:
+
+                 t1_temp.write(t1_file.getvalue())
+                 t1ce_temp.write(t1ce_file.getvalue())
+                 t2_temp.write(t2_file.getvalue())
+                 flair_temp.write(flair_file.getvalue())
+                 # Flush so nibabel sees the full contents when loading by path
+                 for f in (t1_temp, t1ce_temp, t2_temp, flair_temp):
+                     f.flush()
+
+                 # Pass the temporary file paths to the combine function
+                 combined_nifti_img = combine_nifti_files(t1_temp.name, t1ce_temp.name, t2_temp.name, flair_temp.name)
+
+                 original_data = combined_nifti_img.get_fdata()
+
+                 # Save the combined NIfTI object to disk so preprocess_nifti can load it
+                 temp_combined_file_path = "combined_4d.nii.gz"
+                 nib.save(combined_nifti_img, temp_combined_file_path)
+
+                 _, processed_tensor = preprocess_nifti(temp_combined_file_path)
+
+                 if original_data is not None and processed_tensor is not None:
+                     st.success("Preprocessing complete!")
+
+                     # --- Patch-based Model Prediction ---
+                     st.info("Running patch-based model inference...")
+                     try:
+                         prediction_tensor = run_patch_inference(model, processed_tensor, patch_depth=32)
+                         st.success("Prediction complete!")
+                     except Exception as e:
+                         st.error(f"Error during patch-based inference: {e}")
+                         raise
+
+                     # Post-process the prediction into a mask, resized back to the original size
+                     predicted_mask = postprocess_mask(prediction_tensor, original_data.shape)
+
+                     if predicted_mask is not None:
+                         st.header("Results")
+                         # Ensure the mask is int and the shape matches for visualization
+                         max_slices = original_data.shape[2]
+                         slice_index = st.slider("Select an axial slice to view", 0, max_slices - 1, max_slices // 2)
+                         fig = visualize_prediction(original_data, predicted_mask.astype(int), slice_index=slice_index)
+                         st.pyplot(fig)
+                     else:
+                         st.error("Could not post-process the model's prediction.")
+
+         except Exception as e:
+             st.error(f"An error occurred during processing: {e}")
+             st.error("Please ensure the uploaded files are valid NIfTI files with the same dimensions.")
+         finally:
+             # Clean up the temporary combined file if it was created
+             if temp_combined_file_path and os.path.exists(temp_combined_file_path):
+                 os.remove(temp_combined_file_path)
+
+ # --- Footer ---
+ st.markdown("---")
+ st.markdown("Developed with PyTorch and Streamlit.")
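The depth-wise patch loop in `run_patch_inference` can be sketched outside the app; a minimal NumPy version with a plain function standing in for the torch model (`run_patch_inference_np` is an illustrative name, not part of the repo):

```python
import numpy as np

def run_patch_inference_np(model_fn, volume, patch_depth=32):
    """Depth-wise patch loop mirroring run_patch_inference in app.py,
    with NumPy arrays standing in for torch tensors."""
    _, _, d, _, _ = volume.shape
    outputs = []
    for start in range(0, d, patch_depth):
        end = min(start + patch_depth, d)  # the last patch may be shorter
        outputs.append(model_fn(volume[:, :, start:end, :, :]))
    return np.concatenate(outputs, axis=2)  # stitch back along depth

vol = np.arange(1 * 4 * 10 * 2 * 2, dtype=float).reshape(1, 4, 10, 2, 2)
out = run_patch_inference_np(lambda p: p * 2.0, vol, patch_depth=4)
# Stitching restores the original depth, so the shapes match
assert out.shape == vol.shape
```

Because the model preserves each patch's depth, concatenating along `dim=2` reconstructs a prediction the size of the full input volume.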
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ streamlit
+ torch
+ torchvision
+ numpy
+ matplotlib
+ pandas
+ nibabel
+ Pillow
+ scikit-image
unet3d_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24b57e013f9f929e2a2c39af499ec86cf6dae2d42eb4481a0ccd4e2b94a984f7
+ size 22655363
unet_model.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class DoubleConv(nn.Module):
+     """(convolution => GroupNorm => ReLU) * 2"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.double_conv = nn.Sequential(
+             nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.GroupNorm(num_groups=out_channels // 2, num_channels=out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         return self.double_conv(x)
+
+ class Down(nn.Module):
+     """Downscaling with maxpool then double conv"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.encoder = nn.Sequential(
+             nn.MaxPool3d(2),
+             DoubleConv(in_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.encoder(x)
+
+ class Up(nn.Module):
+     """Upscaling then double conv"""
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
+         self.conv = DoubleConv(in_channels, out_channels)
+
+     def forward(self, x1, x2):
+         x1 = self.up(x1)
+         diffX = x2.size()[2] - x1.size()[2]
+         diffY = x2.size()[3] - x1.size()[3]
+         diffZ = x2.size()[4] - x1.size()[4]
+
+         x1 = F.pad(x1, [diffZ // 2, diffZ - diffZ // 2,
+                         diffY // 2, diffY - diffY // 2,
+                         diffX // 2, diffX - diffX // 2])
+
+         x = torch.cat([x2, x1], dim=1)
+         return self.conv(x)
+
+ class OutConv(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(OutConv, self).__init__()
+         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         return self.conv(x)
+
+ class UNet3d(nn.Module):
+     def __init__(self, n_channels=4, n_classes=3):
+         super().__init__()
+         self.n_channels = n_channels
+         self.n_classes = n_classes
+
+         # Contracting path
+         self.conv = DoubleConv(n_channels, 16)
+         self.enc1 = Down(16, 32)
+         self.enc2 = Down(32, 64)
+         self.enc3 = Down(64, 128)
+         self.enc4 = Down(128, 256)
+
+         # Expansive path
+         self.dec1 = Up(256 + 128, 128)
+         self.dec2 = Up(128 + 64, 64)
+         self.dec3 = Up(64 + 32, 32)
+         self.dec4 = Up(32 + 16, 16)
+
+         self.out = OutConv(16, n_classes)
+
+     def forward(self, x):
+         x1 = self.conv(x)
+         x2 = self.enc1(x1)
+         x3 = self.enc2(x2)
+         x4 = self.enc3(x3)
+         x5 = self.enc4(x4)
+
+         x = self.dec1(x5, x4)
+         x = self.dec2(x, x3)
+         x = self.dec3(x, x2)
+         x = self.dec4(x, x1)
+
+         logits = self.out(x)
+         return logits
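`Up.forward` pads the upsampled tensor so it can be concatenated with the skip connection; the same size-matching arithmetic can be sketched in NumPy (`F.pad` takes its pad widths last-axis-first, while `np.pad` takes them per axis, but the two-sided halves are computed identically; `pad_to_match` is an illustrative name, not part of the repo):

```python
import numpy as np

def pad_to_match(x1, x2):
    """Pad x1's spatial dims (D, H, W) up to x2's, splitting each size
    difference across both sides, as Up.forward does before torch.cat."""
    pads = [(0, 0), (0, 0)]  # leave batch and channel axes alone
    for axis in (2, 3, 4):
        diff = x2.shape[axis] - x1.shape[axis]
        pads.append((diff // 2, diff - diff // 2))
    return np.pad(x1, pads)

x1 = np.ones((1, 16, 5, 6, 7))   # upsampled feature map
x2 = np.ones((1, 16, 8, 8, 8))   # skip connection to match
assert pad_to_match(x1, x2).shape == (1, 16, 8, 8, 8)
```

This is why the patch depth in `app.py` does not have to be a multiple of 16: any size mismatch after upsampling is absorbed by the padding.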
utils.py ADDED
@@ -0,0 +1,123 @@
+ import nibabel as nib
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import streamlit as st
+ from skimage.transform import resize
+
+ def preprocess_nifti(nifti_file):
+     """
+     Loads a NIfTI file, preprocesses it, and returns a PyTorch tensor.
+
+     Args:
+         nifti_file (str or io.BytesIO): Path to the NIfTI file or a file-like object.
+
+     Returns:
+         tuple: A tuple containing the original image data and a preprocessed tensor.
+     """
+     try:
+         nifti_img = nib.load(nifti_file)
+         img_data = nifti_img.get_fdata()
+
+         if len(img_data.shape) != 4 or img_data.shape[-1] != 4:
+             st.error("The uploaded NIfTI file must be a 4D image with 4 channels.")
+             return None, None
+
+         # Copy before normalizing in place: get_fdata() caches its array,
+         # so a second call would return the already-normalized data.
+         original_data = img_data.copy()
+
+         for i in range(img_data.shape[-1]):
+             channel_data = img_data[..., i]
+             if np.max(channel_data) > 0:
+                 img_data[..., i] = channel_data / np.max(channel_data)
+
+         img_data = np.transpose(img_data, (3, 2, 0, 1))
+
+         tensor_data = torch.from_numpy(img_data).float()
+         tensor_data = torch.unsqueeze(tensor_data, 0)
+
+         return original_data, tensor_data
+
+     except Exception as e:
+         st.error(f"Error during NIfTI preprocessing: {e}")
+         return None, None
+
+ def postprocess_mask(prediction_tensor, original_shape):
+     """
+     Converts model output (tensor) into a visualizable mask (numpy array)
+     and resizes it to the original image dimensions.
+     """
+     try:
+         probabilities = F.softmax(prediction_tensor, dim=1)
+         mask = torch.argmax(probabilities, dim=1)
+         mask = mask.detach().cpu().numpy()
+         mask = np.squeeze(mask)
+
+         # Undo the depth-first ordering used for the model input
+         mask = np.transpose(mask, (1, 2, 0))
+
+         # Nearest-neighbor resize (order=0) preserves the integer class labels
+         resized_mask = resize(mask, original_shape[:3], order=0, preserve_range=True, anti_aliasing=False)
+
+         return resized_mask
+     except Exception as e:
+         st.error(f"Error during mask post-processing: {e}")
+         return None
+
+ def visualize_prediction(original_image, predicted_mask, slice_index=75):
+     """
+     Creates a 2-panel visualization of the original image and the predicted mask.
+     """
+     fig, axes = plt.subplots(1, 2, figsize=(15, 7))
+     # Show the FLAIR channel for the original image
+     axes[0].imshow(np.rot90(original_image[:, :, slice_index, 3]), cmap='bone')
+     axes[0].set_title('Original Image (FLAIR)', fontsize=16)
+     axes[0].axis('off')
+
+     axes[1].imshow(np.rot90(original_image[:, :, slice_index, 3]), cmap='bone')
+     # Overlay the predicted mask, hiding the background (label 0)
+     mask_slice = np.rot90(predicted_mask[:, :, slice_index])
+     axes[1].imshow(np.ma.masked_where(mask_slice == 0, mask_slice), cmap='jet', alpha=0.5)
+     axes[1].set_title('Predicted Tumor Mask', fontsize=16)
+     axes[1].axis('off')
+     return fig
+
+ def combine_nifti_files(t1_file_path, t1ce_file_path, t2_file_path, flair_file_path):
+     """
+     Combines four 3D NIfTI files from given paths into a single 4D NIfTI file object.
+
+     Args:
+         t1_file_path, t1ce_file_path, t2_file_path, flair_file_path (str): Paths to the temporary NIfTI files.
+
+     Returns:
+         nib.Nifti1Image: A 4D NIfTI image object.
+     """
+     try:
+         # Load the four 3D NIfTI files from file paths
+         t1_img = nib.load(t1_file_path)
+         t1ce_img = nib.load(t1ce_file_path)
+         t2_img = nib.load(t2_file_path)
+         flair_img = nib.load(flair_file_path)
+
+         # Get the image data as NumPy arrays
+         t1_data = t1_img.get_fdata()
+         t1ce_data = t1ce_img.get_fdata()
+         t2_data = t2_img.get_fdata()
+         flair_data = flair_img.get_fdata()
+
+         # Ensure all files have the same shape
+         if not (t1_data.shape == t1ce_data.shape == t2_data.shape == flair_data.shape):
+             st.error("Error: Input NIfTI files do not have matching dimensions.")
+             return None
+
+         # Stack the 3D arrays along a new (4th) dimension to create a 4D array
+         combined_data = np.stack([t1_data, t1ce_data, t2_data, flair_data], axis=-1)
+
+         # Create a new 4D NIfTI image object, reusing the T1 affine
+         combined_img = nib.Nifti1Image(combined_data, t1_img.affine)
+
+         return combined_img
+     except Exception as e:
+         st.error(f"Error combining NIfTI files: {e}")
+         return None
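The per-channel scaling step in `preprocess_nifti` divides each modality by its own maximum so the four channels land in a common [0, 1] range; a minimal NumPy sketch of that step on a toy 4-channel volume (`normalize_channels` is an illustrative helper, not part of `utils.py`):

```python
import numpy as np

def normalize_channels(img_data):
    """Scale each modality channel (last axis) into [0, 1] by its own
    maximum, as preprocess_nifti does; all-zero channels stay untouched."""
    img_data = img_data.astype(float).copy()
    for i in range(img_data.shape[-1]):
        peak = np.max(img_data[..., i])
        if peak > 0:
            img_data[..., i] /= peak
    return img_data

# Toy 4-channel volume with very different intensity ranges per modality
vol = np.stack([np.arange(8.0).reshape(2, 2, 2) * s for s in (1, 10, 100, 0)], axis=-1)
normed = normalize_channels(vol)
assert normed.max() == 1.0          # every non-empty channel peaks at 1
assert normed[..., 3].max() == 0.0  # the all-zero channel is left alone
```

Per-channel (rather than global) scaling matters because raw MRI modalities have arbitrary, non-comparable intensity units.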