File size: 3,698 Bytes
03642cc
 
 
c9d94d6
03642cc
 
 
 
 
 
c9d94d6
 
 
 
 
03642cc
 
c9d94d6
 
 
 
 
03642cc
 
c9d94d6
03642cc
c9d94d6
 
03642cc
c9d94d6
03642cc
c9d94d6
 
 
 
 
 
 
 
 
 
03642cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d94d6
03642cc
 
 
 
 
 
 
 
 
 
 
 
c9d94d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
import torch
from scipy.ndimage import binary_closing
from scipy import ndimage

from utils.functions import normalize, reimburse_conform


def crop(voxel, model, device):
    """
    Apply a neural network-based cropping operation on 3D voxel data.

    Slides a 3-slice window along the first axis of the input volume and runs the
    model on each window. The sigmoid of the model output becomes the prediction
    for the window's centre slice. The per-slice predictions are collected into a
    3D probability volume.

    Args:
        voxel (numpy.ndarray): Input 3D array of shape (N, H, W). The window slides
            along the first dimension. The pipeline in this module feeds
            (256, 256, 256) volumes, but any 3D shape is accepted.
        model (torch.nn.Module): Trained model. It takes a (1, 3, H, W) tensor and
            returns an output with H * W elements for the centre slice
            (e.g. shape (1, 1, H, W)).
        device (torch.device): The device (CPU, CUDA, or MPS) on which inference runs.

    Returns:
        torch.Tensor: CPU float32 tensor of shape (N, H, W) holding per-voxel
        sigmoid probabilities. It is not thresholded; callers binarize it.
    """
    n_slices, height, width = voxel.shape

    # Pad one slice at each end (with the volume minimum) so the first and last
    # slices also get a full 3-slice context.
    voxel = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())
    model.eval()

    with torch.inference_mode():
        box = torch.zeros(n_slices, height, width)

        # Padded index i is the centre slice; it maps to output slice i - 1.
        for i in range(1, n_slices + 1):
            image = np.stack([voxel[i - 1], voxel[i], voxel[i + 1]])
            image = torch.from_numpy(image.reshape(1, 3, height, width)).to(device)
            # inference_mode already disables autograd, so no .detach() is needed.
            x_out = torch.sigmoid(model(image)).cpu()
            box[i - 1] = x_out.reshape(height, width)

        return box


def closing(voxel, iterations=3):
    """
    Perform a binary closing operation on a 3D voxel array.

    Applies ``scipy.ndimage.binary_closing`` (dilation followed by erosion) with a
    full 3x3x3 structuring element. Both steps are repeated ``iterations`` times.
    This fills holes and gaps smaller than the resulting neighbourhood.

    Parameters:
    voxel (numpy.ndarray): A 3D array representing the mask to be closed;
        non-zero values are treated as foreground.
    iterations (int): Number of times the dilation and the erosion are each
        repeated. Defaults to 3, the value this function always used.

    Returns:
    numpy.ndarray: Boolean array of the same shape as ``voxel`` after the closing.

    Note:
    ``binary_closing`` is called with its default ``border_value=0``, so
    foreground within ``iterations`` voxels of the array edge may be eroded.
    """
    selem = np.ones((3, 3, 3), dtype="bool")
    return binary_closing(voxel, structure=selem, iterations=iterations)


def cropping(output_dir, basename, odata, data, cnet, device):
    """
    Crops the input medical imaging data using a neural network model.

    The normalized volume goes through ``crop`` twice: once with its axes
    permuted (labelled "coronal" here) and once as-is (labelled "sagittal").
    The two probability volumes are averaged and thresholded at 0.5. The
    resulting boolean mask is smoothed with ``closing``, applied to the raw
    intensities, and passed to ``reimburse_conform``. The masked volume is then
    rolled so the mask's centre of mass moves to voxel (128, 120, 128). Finally,
    16 voxels are trimmed from both ends of every axis.

    Args:
        output_dir: Output location forwarded to ``reimburse_conform``.
        basename: Base name forwarded to ``reimburse_conform``.
        odata: Forwarded to ``reimburse_conform`` unchanged. NOTE(review):
            presumably the original (pre-conform) image — confirm against that helper.
        data (nibabel.Nifti1Image): The input image; its ``get_fdata()`` supplies
            the volume. The hard-coded centre and 16-voxel trim appear to assume
            a 256^3 volume — confirm.
        cnet (torch.nn.Module): The neural network model used for cropping.
        device (torch.device): The device (CPU or GPU) on which the model is run.

    Returns:
        tuple: ``(cropped, (xd, yd, zd), out_filename)``.
        - ``cropped``: the masked, rolled and trimmed float32 volume.
        - ``(xd, yd, zd)``: the integer shifts passed to ``np.roll`` along
          axes 0, 1 and 2; presumably returned so callers can undo the shift.
        - ``out_filename``: whatever ``reimburse_conform`` returned.

    Raises:
        ValueError: If the predicted mask is empty, since its centre of mass is
            then undefined.
    """
    raw = data.get_fdata().astype("float32")
    # Give normalize() its own copy in case it modifies its argument in place;
    # the untouched raw intensities are what gets masked below.
    voxel = normalize(raw.copy(), "cropping")

    coronal = voxel.transpose(1, 2, 0)
    sagittal = voxel
    # permute(2, 0, 1) undoes transpose(1, 2, 0), so both predictions share voxel's axis order.
    out_c = crop(coronal, cnet, device).permute(2, 0, 1)
    out_s = crop(sagittal, cnet, device)
    out_e = (((out_c + out_s) / 2) > 0.5).cpu().numpy()
    out_e = closing(out_e)
    cropped = raw * out_e

    out_filename = reimburse_conform(output_dir, basename, "cropped", odata, data, out_e)

    # An empty mask has an undefined (NaN) centre of mass; int(nan) would fail
    # with an opaque message, so fail explicitly instead.
    if not out_e.any():
        raise ValueError("cropping: predicted mask is empty; cannot compute its center of mass")

    # Compute center of mass for the masked brain (int() truncates toward zero).
    x, y, z = map(int, ndimage.center_of_mass(out_e))

    # Shifts that move the centre of mass to (128, 120, 128).
    # NOTE(review): the y target is 120, not 128 — presumably intentional; confirm.
    xd = 128 - x
    yd = 120 - y
    zd = 128 - z

    # Translate (roll, with wrap-around) the image to center the brain region.
    cropped = np.roll(cropped, (xd, yd, zd), axis=(0, 1, 2))

    # Trim 16 voxels from both ends of every axis (256 -> 224 for a 256^3 input).
    cropped = cropped[16:-16, 16:-16, 16:-16]

    return cropped, (xd, yd, zd), out_filename