Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from scipy.ndimage import binary_closing | |
| from scipy import ndimage | |
| from utils.functions import normalize, reimburse_conform | |
def crop(voxel, model, device):
    """
    Predict a per-voxel probability map for a 3D volume, one slice at a time.

    Each slice along the first axis is fed to ``model`` together with its two
    neighbours as a 3-channel image. The sigmoid of the model output becomes
    that slice of the result. The volume is padded by one slice at each end
    (filled with the volume's minimum value), so the first and last slices
    also get a full 3-slice context.

    Args:
        voxel (numpy.ndarray): Input 3D array of shape (N, H, W). The first
            dimension is the slice index. The cropping pipeline in this module
            passes 256x256x256 volumes, but any (N, H, W) is handled.
        model (torch.nn.Module): Trained network. It is switched to eval mode
            and called on (1, 3, H, W) tensors. Its output must be assignable
            to an (H, W) slice (e.g. shape (1, 1, H, W)). Inputs are moved to
            ``device``; the model itself is not moved here.
        device (torch.device): Device (CPU, CUDA, or MPS) on which inference
            runs.

    Returns:
        torch.Tensor: CPU tensor of shape (N, H, W) holding sigmoid
        probabilities in [0, 1] (not thresholded).
    """
    n_slices, height, width = voxel.shape
    # One extra slice of "background" (the volume minimum) at each end gives
    # the first and last slices a full 3-slice neighbourhood.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
    model.eval()
    with torch.inference_mode():
        box = torch.zeros(n_slices, height, width)
        for i in range(n_slices):
            # padded[i:i + 3] holds original slices i - 1, i and i + 1.
            window = padded[i:i + 3].reshape(1, 3, height, width)
            image = torch.tensor(window).to(device)
            box[i] = torch.sigmoid(model(image)).cpu()
    return box
def closing(voxel, iterations=3):
    """
    Perform a binary closing (dilation followed by erosion) on a 3D voxel array.

    A full 3x3x3 structuring element is applied ``iterations`` times.
    Repeating it k times is equivalent to using a (2k+1)-wide cube, so
    background holes and gaps narrower than that are filled.

    The volume is padded with background by ``iterations`` voxels before the
    closing, and the padding is removed afterwards. Without this,
    ``scipy.ndimage.binary_closing`` treats everything outside the array as
    background during the erosion step. Foreground touching the array border
    would then be eaten away, whereas a closing should never remove
    foreground voxels.

    Parameters:
        voxel (numpy.ndarray): A 3D array; non-zero entries are foreground.
        iterations (int): Number of dilation/erosion repetitions. Must be
            >= 1. Defaults to 3.

    Returns:
        numpy.ndarray: Boolean array with the same shape as ``voxel``.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        # scipy treats iterations < 1 as "repeat until stable", for which no
        # finite padding width can be chosen in advance.
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    selem = np.ones((3, 3, 3), dtype="bool")
    # Each iteration with a 3x3x3 element grows the mask by one voxel per
    # side, so this padding keeps the whole dilation inside the array.
    padded = np.pad(np.asarray(voxel, dtype=bool), iterations, mode="constant", constant_values=False)
    closed = binary_closing(padded, structure=selem, iterations=iterations)
    inner = tuple(slice(iterations, -iterations) for _ in range(closed.ndim))
    return closed[inner]
def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Mask the brain in a volume with a neural network and center it.

    Steps:
    1. The normalized volume is run through ``cnet`` twice, slicing along
       axis 1 (``coronal`` below) and along axis 0 (``sagittal``).
    2. The two probability maps are averaged, thresholded at 0.5 and cleaned
       with a binary closing.
    3. The raw (un-normalized) intensities are multiplied by this mask.
    4. The result is rolled so that the mask's center of mass lands at voxel
       (128, 120, 128).
    5. 16 voxels are trimmed from every side.

    Args:
        output_dir: Passed through to ``reimburse_conform``; presumably the
            directory the mask is written to (confirm in utils.functions).
        basename: Passed through to ``reimburse_conform``; presumably the
            output file stem.
        odata: Passed through to ``reimburse_conform``; presumably the
            original (pre-conform) image.
        data (nibabel.Nifti1Image): The input image. ``get_fdata()`` is read
            twice: once normalized for the network, once raw for the output.
            Assumed to be 256x256x256 (conformed) — TODO confirm upstream.
        cnet (torch.nn.Module): The cropping network, applied via ``crop``.
        device (torch.device): The device (CPU or GPU) on which the model is
            run.

    Returns:
        tuple: ``(cropped, shift, out_filename)`` where

        - ``cropped`` (numpy.ndarray) is the masked, centered float32 volume,
          32 voxels shorter than the input along every axis.
        - ``shift`` (tuple of int) is ``(xd, yd, zd)``, the ``np.roll``
          offsets applied along axes 0, 1 and 2.
        - ``out_filename`` is whatever ``reimburse_conform`` returned.

    Raises:
        ValueError: If the predicted mask is empty, so it has no center of
            mass.
    """
    # Voxel that the mask's center of mass is moved to, and the border width
    # trimmed from every side after the move.
    target_x, target_y, target_z = 128, 120, 128
    margin = 16

    voxel = data.get_fdata().astype("float32")
    voxel = normalize(voxel, "cropping")

    # Predict once slicing along axis 1 and once along axis 0, then bring the
    # axis-1 prediction back into the original (0, 1, 2) axis order.
    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    out_c = crop(coronal, cnet, device).permute(2, 0, 1)
    out_s = crop(sagittal, cnet, device)

    # Average the two probability maps, threshold, then fill small holes.
    out_e = ((out_c + out_s) / 2) > 0.5
    out_e = out_e.cpu().numpy()
    out_e = closing(out_e)

    # Mask the raw (un-normalized) intensities.
    cropped = data.get_fdata().astype("float32") * out_e
    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)

    # An empty mask has no center of mass: center_of_mass would yield NaN and
    # int() would then fail with an unhelpful message.
    if not out_e.any():
        raise ValueError(
            "cropping: the predicted brain mask is empty; cannot compute its center of mass"
        )
    # int() truncates the (fractional) center of mass toward zero.
    x, y, z = map(int, ndimage.center_of_mass(out_e))

    # Shifts that move the center of mass onto the target voxel.
    xd = target_x - x
    yd = target_y - y
    zd = target_z - z
    cropped = np.roll(cropped, (xd, yd, zd), axis=(0, 1, 2))

    # Trim the border to shrink the volume around the centered brain.
    cropped = cropped[margin:-margin, margin:-margin, margin:-margin]
    return cropped, (xd, yd, zd), out_filename