import numpy as np
import torch
from scipy.ndimage import binary_closing
from scipy import ndimage

from utils.functions import normalize, reimburse_conform


def crop(voxel, model, device):
    """
    Predict a per-slice probability map for a 3D volume with a 2.5D network.

    A 3-slice window slides along the first axis of ``voxel``. For every target
    slice, the model receives that slice together with its two neighbours as a
    3-channel image. Each output is passed through a sigmoid and written into
    the corresponding slice of the result volume.

    Args:
        voxel (numpy.ndarray): Input 3D array of shape (N, H, W). The first axis
            is the slice axis the window slides along. In this pipeline the input
            is a 256x256x256 volume.
        model (torch.nn.Module): Trained network that maps a (1, 3, H, W) tensor
            to a single-channel logit map with H*W elements (e.g. (1, 1, H, W)).
        device (torch.device): Device (CPU, CUDA, or MPS) on which inference runs.

    Returns:
        torch.Tensor: CPU tensor of shape (N, H, W) holding sigmoid probabilities
        in [0, 1]. The map is NOT thresholded; callers binarize it themselves.
    """
    n_slices, height, width = voxel.shape
    # Pad one slice at each end with the volume minimum so the first and last
    # slices also get a full 3-slice context.
    padded = np.pad(
        voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min()
    )

    model.eval()
    with torch.inference_mode():
        probs = torch.zeros(n_slices, height, width)
        for i in range(n_slices):
            # padded[i:i + 3] holds original slices i-1, i, i+1 (with padding
            # at the edges). torch.tensor copies, so the model never aliases
            # the padded buffer.
            image = torch.tensor(padded[i:i + 3][np.newaxis]).to(device)
            logits = model(image)
            probs[i] = torch.sigmoid(logits).cpu().reshape(height, width)
    return probs


def closing(voxel, iterations=3):
    """
    Apply a 3D binary morphological closing to a mask.

    Uses a full 3x3x3 structuring element. With ``iterations=n``, SciPy performs
    n dilations followed by n erosions.

    Parameters:
        voxel (numpy.ndarray): 3D array interpreted as a binary mask.
        iterations (int): Number of dilation/erosion repetitions. Defaults to 3.

    Returns:
        numpy.ndarray: Boolean array of the same shape containing the closed mask.
    """
    selem = np.ones((3, 3, 3), dtype=bool)
    return binary_closing(voxel, structure=selem, iterations=iterations)


def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Mask the input volume with a network-predicted brain mask and centre it.

    Steps:
      1. Normalize the image and run ``crop`` along two orientations with the
         same network.
      2. Average the two probability maps and threshold at 0.5.
      3. Apply a binary closing to the thresholded mask.
      4. Multiply the raw intensities by the mask.
      5. Hand the mask to ``reimburse_conform``.
      6. Roll the masked volume so the mask's centre of mass lands at
         (128, 120, 128).
      7. Trim 16 voxels from every side.

    Args:
        output_dir: Output location forwarded to ``reimburse_conform``
            (presumably a directory path — confirm against that helper).
        basename: Base name forwarded to ``reimburse_conform``
            (presumably used to build the output filename — confirm).
        odata: Image forwarded to ``reimburse_conform``. Presumably the original,
            pre-conform image — confirm against the caller.
        data (nibabel.Nifti1Image): Input image. Its data is read with
            ``get_fdata()``. The fixed centring target and 16-voxel trim assume a
            256x256x256 volume.
        cnet (torch.nn.Module): Network used by ``crop`` for both orientations.
        device (torch.device): Device on which the network is run.

    Returns:
        tuple:
            - numpy.ndarray: masked, re-centred and trimmed float32 volume
              (224x224x224 for a 256^3 input).
            - tuple[int, int, int]: (xd, yd, zd) shifts applied by ``np.roll``.
            - The value returned by ``reimburse_conform`` (named ``out_filename``
              here).

    Raises:
        ValueError: If the predicted mask is empty, so no centre of mass exists.
    """
    voxel = data.get_fdata().astype("float32")
    voxel = normalize(voxel, "cropping")

    # Predict along two orientations with the same network. The coronal result
    # is permuted back to the original axis order before averaging.
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    out_c = crop(coronal, cnet, device).permute(2, 0, 1)
    out_s = crop(sagittal, cnet, device)

    # Ensemble the two probability maps, binarize, then close small gaps.
    mask = (((out_c + out_s) / 2) > 0.5).cpu().numpy()
    mask = closing(mask)

    cropped = data.get_fdata().astype("float32") * mask
    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, mask)

    # center_of_mass of an all-False mask is NaN, and int(NaN) fails with an
    # unhelpful message; report the real cause instead.
    if not mask.any():
        raise ValueError(
            "Cropping network produced an empty brain mask; "
            "cannot compute its center of mass."
        )

    # Compute center of mass for the masked brain
    x, y, z = map(int, ndimage.center_of_mass(mask))

    # Shifts that move the centre of mass to the fixed target (128, 120, 128).
    xd = 128 - x
    yd = 120 - y
    zd = 128 - z

    # np.roll wraps voxels around the array edges rather than padding.
    cropped = np.roll(cropped, (xd, yd, zd), axis=(0, 1, 2))

    # Trim 16 voxels from each side of every axis.
    cropped = cropped[16:-16, 16:-16, 16:-16]
    return cropped, (xd, yd, zd), out_filename