import io
import os
import tempfile

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from skimage.transform import resize

# Number of modalities expected on the last axis of the 4D volume
# (combine_nifti_files stacks them as T1, T1ce, T2, FLAIR).
_N_MODALITIES = 4
# Index of the FLAIR modality on the last axis (see combine_nifti_files).
_FLAIR_CHANNEL = 3
# First two bytes of any gzip stream; used to pick a .nii vs .nii.gz suffix.
_GZIP_MAGIC = b"\x1f\x8b"


def _load_volume(nifti_source):
    """Return the float voxel array for a path, a loaded nibabel image, or a binary stream.

    Args:
        nifti_source: A filename / path-like, an object exposing ``get_fdata``
            (e.g. the image returned by ``combine_nifti_files``), or a readable
            binary file-like object (e.g. ``io.BytesIO`` or a Streamlit upload).

    Returns:
        numpy.ndarray: The image data as returned by ``get_fdata()``.
    """
    # Already-loaded nibabel image.
    if hasattr(nifti_source, "get_fdata"):
        return nifti_source.get_fdata()

    # Anything without .read() is treated as a path and handed to nibabel.
    if not hasattr(nifti_source, "read"):
        return nib.load(nifti_source).get_fdata()

    # nib.load only accepts filenames, so spool the stream to a temporary file.
    if hasattr(nifti_source, "seek"):
        nifti_source.seek(0)  # the stream may already have been read by the caller
    payload = nifti_source.read()
    suffix = ".nii.gz" if payload[:2] == _GZIP_MAGIC else ".nii"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(payload)
    try:
        # mmap=False so the returned array does not reference the temp file,
        # which is removed below (a live memory map would also block removal on Windows).
        return nib.load(tmp.name, mmap=False).get_fdata()
    finally:
        os.remove(tmp.name)


def preprocess_nifti(nifti_file):
    """
    Loads a NIfTI volume, normalizes each channel, and returns a PyTorch tensor.

    Args:
        nifti_file (str, os.PathLike, io.BytesIO or nibabel image): Path to the
            NIfTI file, a readable binary file-like object, or an already-loaded
            nibabel image.

    Returns:
        tuple: ``(original_data, tensor)`` where ``original_data`` is the
        un-normalized array of shape (D0, D1, D2, 4) and ``tensor`` is a float32
        tensor of shape (1, 4, D2, D0, D1) with every channel divided by its
        maximum (channels whose maximum is not positive are left unchanged).
        ``(None, None)`` on failure, after reporting the error via ``st.error``.
    """
    try:
        original_data = _load_volume(nifti_file)

        if original_data.ndim != 4 or original_data.shape[-1] != _N_MODALITIES:
            st.error("The uploaded NIfTI file must be a 4D image with 4 channels.")
            return None, None

        # Normalize a copy: get_fdata() caches its result (and for in-memory
        # images may return the image's own array), so in-place normalization
        # would corrupt the "original" data returned below and the source image.
        img_data = original_data.copy()
        for channel in range(img_data.shape[-1]):
            channel_max = np.max(img_data[..., channel])
            if channel_max > 0:
                img_data[..., channel] /= channel_max

        # (D0, D1, D2, C) -> (C, D2, D0, D1); postprocess_mask undoes this order.
        img_data = np.ascontiguousarray(np.transpose(img_data, (3, 2, 0, 1)))
        tensor_data = torch.from_numpy(img_data).float().unsqueeze(0)  # add batch axis

        return original_data, tensor_data
    except Exception as e:
        st.error(f"Error during NIfTI preprocessing: {e}")
        return None, None


def postprocess_mask(prediction_tensor, original_shape):
    """
    Converts model output (tensor) into a label mask (numpy array) resized to
    the original image dimensions.

    Args:
        prediction_tensor (torch.Tensor): Model output with the class axis at
            dim 1. Assumes shape (1, n_classes, D2, D0, D1), i.e. the axis
            order produced by preprocess_nifti; TODO confirm against the model.
        original_shape (tuple): Shape of the original volume; only the first
            three entries (D0, D1, D2) are used.

    Returns:
        numpy.ndarray: Label mask of shape ``original_shape[:3]`` (float dtype,
        as produced by skimage's resize), or ``None`` on failure after
        reporting the error via ``st.error``.
    """
    try:
        # Softmax is monotonic, so argmax over raw scores yields the same
        # labels without materializing a full probability volume.
        labels = torch.argmax(prediction_tensor, dim=1)
        labels = labels.detach().cpu().numpy()

        # Select the batch element explicitly: np.squeeze would also drop any
        # spatial axis of size 1 and break the transpose below.
        if labels.shape[0] != 1:
            raise ValueError(f"expected a batch of 1, got {labels.shape[0]}")
        labels = labels[0]

        # (D2, D0, D1) -> (D0, D1, D2), matching the original volume's layout.
        labels = np.transpose(labels, (1, 2, 0))

        # order=0 (nearest neighbour) keeps values as valid class labels.
        return resize(labels, original_shape[:3], order=0,
                      preserve_range=True, anti_aliasing=False)
    except Exception as e:
        st.error(f"Error during mask post-processing: {e}")
        return None


def visualize_prediction(original_image, predicted_mask, slice_index=75):
    """
    Creates a 2-panel figure: the FLAIR slice alone, and the same slice with
    the predicted mask overlaid.

    Args:
        original_image (numpy.ndarray): 4D volume (D0, D1, D2, C); channel 3 is
            shown as FLAIR.
        predicted_mask (numpy.ndarray): 3D label volume (D0, D1, D2); label 0 is
            treated as background and left transparent.
        slice_index (int or None): Index along the third axis. Out-of-range
            values are clamped to the nearest valid slice; ``None`` selects
            the middle slice.

    Returns:
        matplotlib.figure.Figure: The created figure.
    """
    depth = original_image.shape[2]
    if slice_index is None:
        slice_index = depth // 2
    # Clamp so short volumes don't raise IndexError; negative indices stay valid.
    slice_index = int(np.clip(slice_index, -depth, depth - 1))

    flair_slice = np.rot90(original_image[:, :, slice_index, _FLAIR_CHANNEL])
    mask_slice = np.rot90(predicted_mask[:, :, slice_index])
    # Fixed colour range over the whole mask so each label keeps the same
    # colour regardless of which labels appear on this particular slice.
    label_max = max(float(np.max(predicted_mask)), 1.0)

    fig, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(15, 7))

    ax_image.imshow(flair_slice, cmap='bone')
    ax_image.set_title('Original Image (FLAIR)', fontsize=16)
    ax_image.axis('off')

    ax_overlay.imshow(flair_slice, cmap='bone')
    ax_overlay.imshow(np.ma.masked_where(mask_slice == 0, mask_slice),
                      cmap='jet', alpha=0.5, vmin=0, vmax=label_max)
    ax_overlay.set_title('Predicted Tumor Mask', fontsize=16)
    ax_overlay.axis('off')

    return fig


def combine_nifti_files(t1_file_path, t1ce_file_path, t2_file_path, flair_file_path):
    """
    Combines four 3D NIfTI files into a single 4D NIfTI image.

    The modalities are stacked on a new last axis in the order
    T1, T1ce, T2, FLAIR; the T1 affine is used for the result.

    Args:
        t1_file_path, t1ce_file_path, t2_file_path, flair_file_path (str):
            Paths to the NIfTI files.

    Returns:
        nib.Nifti1Image: A 4D NIfTI image object, or ``None`` if the shapes
        differ or loading fails (the error is reported via ``st.error``).
    """
    try:
        paths = (t1_file_path, t1ce_file_path, t2_file_path, flair_file_path)
        images = [nib.load(path) for path in paths]
        volumes = [img.get_fdata() for img in images]

        if len({vol.shape for vol in volumes}) != 1:
            st.error("Error: Input NIfTI files do not have matching dimensions.")
            return None

        # Only the T1 affine is kept, so surface (but don't fail on) any
        # spatial-orientation mismatch between the inputs.
        reference_affine = images[0].affine
        if reference_affine is not None and any(
            img.affine is not None and not np.allclose(img.affine, reference_affine)
            for img in images[1:]
        ):
            st.warning("Input NIfTI files have different affines; "
                       "using the T1 affine for the combined image.")

        combined_data = np.stack(volumes, axis=-1)
        return nib.Nifti1Image(combined_data, reference_affine)
    except Exception as e:
        st.error(f"Error combining NIfTI files: {e}")
        return None