saketh-005's picture
Upload 6 files
fa7fb3e verified
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import streamlit as st
import io
import tempfile
from skimage.transform import resize
def preprocess_nifti(nifti_file):
    """
    Load a 4-channel NIfTI volume and build a normalised, batched tensor.

    Every channel is divided by its own maximum (channels whose maximum is
    not positive are passed through unchanged). The axes are then reordered
    from (0, 1, 2, channel) to (channel, 2, 0, 1) and a leading batch axis of
    size 1 is added.

    Args:
        nifti_file: Location of the NIfTI file, handed directly to
            ``nib.load``.
            NOTE(review): an earlier docstring also advertised ``io.BytesIO``;
            ``nib.load`` expects a filesystem path, so confirm with callers
            before relying on file-like input.

    Returns:
        tuple: ``(original_data, tensor_data)`` where ``original_data`` is the
        unmodified 4D array read from the file and ``tensor_data`` is a
        float32 tensor of shape (1, 4, dim2, dim0, dim1). ``(None, None)`` is
        returned, after reporting through ``st.error``, when the file is not a
        4D image with 4 channels or when loading/processing raises.
    """
    try:
        nifti_img = nib.load(nifti_file)
        original_data = nifti_img.get_fdata()
        if original_data.ndim != 4 or original_data.shape[-1] != 4:
            st.error("The uploaded NIfTI file must be a 4D image with 4 channels.")
            return None, None

        # get_fdata() hands back the image's cached array, so normalising it
        # in place would also alter the "original" data returned below.
        # Write the scaled values into a separate float32 buffer instead
        # (float32 is what the tensor is converted to anyway).
        normalized = np.empty(original_data.shape, dtype=np.float32)
        for channel in range(original_data.shape[-1]):
            channel_data = original_data[..., channel]
            peak = np.max(channel_data)
            if peak > 0:
                normalized[..., channel] = channel_data / peak
            else:
                # Nothing sensible to scale by; keep the channel as it is.
                normalized[..., channel] = channel_data

        # Channels first, then the third spatial axis, then the first two.
        normalized = np.transpose(normalized, (3, 2, 0, 1))
        tensor_data = torch.from_numpy(normalized).float()
        tensor_data = torch.unsqueeze(tensor_data, 0)
        return original_data, tensor_data
    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing.
        st.error(f"Error during NIfTI preprocessing: {e}")
        return None, None
def postprocess_mask(prediction_tensor, original_shape):
    """
    Turn raw model output into a label volume sized like the input image.

    The class axis (dim 1) is reduced with softmax + argmax, the batch axis
    is dropped, the remaining three axes are reordered with (1, 2, 0), and
    the result is resized with nearest-neighbour interpolation.

    Args:
        prediction_tensor: Model output with the batch on axis 0 and the
            classes on axis 1, followed by three spatial axes. A batch size
            of exactly 1 is required.
            (Presumably laid out as (1, classes, Z, X, Y) to mirror the
            tensor built by ``preprocess_nifti`` - confirm against the model.)
        original_shape: Shape of the original image; only its first three
            entries are used as the target size.

    Returns:
        numpy.ndarray: The resized label volume, or ``None`` (after reporting
        through ``st.error``) if any step raises.
    """
    try:
        probabilities = F.softmax(prediction_tensor, dim=1)
        mask = torch.argmax(probabilities, dim=1)
        mask = mask.detach().cpu().numpy()
        # Remove only the batch axis. A bare np.squeeze() would also drop any
        # spatial axis of length 1 and make the 3-axis transpose below fail.
        mask = np.squeeze(mask, axis=0)
        mask = np.transpose(mask, (1, 2, 0))
        # order=0 with anti_aliasing off keeps label values intact (no blending).
        resized_mask = resize(
            mask,
            original_shape[:3],
            order=0,
            preserve_range=True,
            anti_aliasing=False,
        )
        return resized_mask
    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing.
        st.error(f"Error during mask post-processing: {e}")
        return None
def visualize_prediction(original_image, predicted_mask, slice_index=75):
    """
    Build a side-by-side figure: the image slice alone, and the same slice
    with the predicted mask drawn on top.

    Args:
        original_image: 4D array indexed as [:, :, slice, channel]; channel 3
            is the one displayed (labelled FLAIR).
        predicted_mask: 3D array indexed as [:, :, slice]; zero entries are
            hidden in the overlay.
        slice_index: Index along the third axis to display.

    Returns:
        The matplotlib figure holding both panels.
    """
    fig, (image_ax, overlay_ax) = plt.subplots(1, 2, figsize=(15, 7))

    # The same rotated slice of channel 3 is the backdrop for both panels.
    backdrop = np.rot90(original_image[:, :, slice_index, 3])

    # Left panel: image only.
    image_ax.imshow(backdrop, cmap='bone')
    image_ax.set_title('Original Image (FLAIR)', fontsize=16)
    image_ax.axis('off')

    # Right panel: image with the mask overlaid; zero-valued mask pixels are
    # masked out so the backdrop stays visible there.
    overlay_ax.imshow(backdrop, cmap='bone')
    rotated_mask = np.rot90(predicted_mask[:, :, slice_index])
    visible_mask = np.ma.masked_where(rotated_mask == 0, rotated_mask)
    overlay_ax.imshow(visible_mask, cmap='jet', alpha=0.5)
    overlay_ax.set_title('Predicted Tumor Mask', fontsize=16)
    overlay_ax.axis('off')

    return fig
def combine_nifti_files(t1_file_path, t1ce_file_path, t2_file_path, flair_file_path):
    """
    Merge four NIfTI volumes into one image with the modalities stacked
    along a new trailing axis.

    Args:
        t1_file_path, t1ce_file_path, t2_file_path, flair_file_path (str):
            Paths to the NIfTI files. The resulting channel order is
            T1, T1ce, T2, FLAIR.

    Returns:
        nib.Nifti1Image: The stacked image, carrying the affine of the T1
        file. ``None`` is returned (after reporting through ``st.error``) if
        the volumes differ in shape or if loading/stacking raises.
    """
    try:
        modality_paths = (t1_file_path, t1ce_file_path, t2_file_path, flair_file_path)

        # Load every image first, then pull out the voxel arrays.
        images = [nib.load(path) for path in modality_paths]
        volumes = [image.get_fdata() for image in images]

        # All volumes have to agree with the first one to be stackable.
        reference_shape = volumes[0].shape
        if any(volume.shape != reference_shape for volume in volumes[1:]):
            st.error("Error: Input NIfTI files do not have matching dimensions.")
            return None

        # The new last axis indexes the modality.
        stacked = np.stack(volumes, axis=-1)
        return nib.Nifti1Image(stacked, images[0].affine)
    except Exception as e:
        st.error(f"Error combining NIfTI files: {e}")
        return None