File size: 4,178 Bytes
03642cc
 
c9d94d6
03642cc
 
 
c9d94d6
03642cc
c9d94d6
 
 
 
 
 
 
03642cc
 
c9d94d6
 
 
03642cc
 
c9d94d6
 
03642cc
 
 
c9d94d6
 
 
03642cc
c9d94d6
 
03642cc
c9d94d6
 
 
 
03642cc
c9d94d6
 
 
03642cc
c9d94d6
 
03642cc
c9d94d6
 
03642cc
c9d94d6
 
 
 
 
 
 
03642cc
 
c9d94d6
 
 
03642cc
 
c9d94d6
 
 
 
03642cc
c9d94d6
 
03642cc
c9d94d6
03642cc
 
 
c9d94d6
 
 
03642cc
c9d94d6
03642cc
 
c9d94d6
 
03642cc
c9d94d6
03642cc
 
c9d94d6
 
 
 
 
 
 
03642cc
 
c9d94d6
 
 
03642cc
 
c9d94d6
03642cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from scipy.ndimage import binary_dilation
import numpy as np
from utils.functions import normalize


def separate(voxel, model, device):
    """
    Perform slice-wise inference using a hemisphere separation model.

    This function runs a 2.5D neural network across slices of a 3D input volume.
    Each slice is processed together with its immediate neighbors (previous and
    next slices), which are stacked as the three input channels. The model output
    is passed through a softmax over the class axis, giving a three-class
    probability map distinguishing background, left hemisphere, and right
    hemisphere regions.

    The number of slices and the in-plane size are read from ``voxel.shape``, so
    every slice of the input is processed. The model must accept inputs of shape
    (1, 3, H, W) and return logits of shape (1, 3, H, W).
    NOTE(review): the hemisphere model is presumably trained on 224x224 slices;
    confirm before feeding other in-plane sizes.

    Args:
        voxel (numpy.ndarray): Input volume of shape (N, H, W) (N = H = W = 224
            for the standard pipeline). The first and last slices get their
            missing neighbor filled with ``voxel.min()``.
        model (torch.nn.Module): Trained hemisphere segmentation model (U-Net
            architecture). It is switched to eval mode by this function.
        device (torch.device): Computational device (CPU, CUDA, or MPS) the
            input slices are moved to.

    Returns:
        torch.Tensor: A CPU tensor of shape (N, 3, H, W) containing softmax
        probabilities for each class at every voxel.
    """
    model.eval()

    n_slices, height, width = voxel.shape

    # Pad the volume by one slice on both ends (with the volume minimum) so the
    # first and last slices also get a full 3-slice context window.
    padded = np.pad(voxel, [(1, 1), (0, 0), (0, 0)], "constant", constant_values=voxel.min())

    with torch.inference_mode():
        # Output tensor for storing model predictions (class probabilities)
        box = torch.zeros(n_slices, 3, height, width)

        for i in range(n_slices):
            # padded[i:i + 3] holds original slices i-1, i, i+1 (padding at the ends).
            window = padded[i:i + 3].reshape(1, 3, height, width)
            image = torch.tensor(window).to(device)

            # Softmax across the class axis; keep the single batch element.
            box[i] = torch.softmax(model(image), dim=1).cpu()[0]

        return box


def hemisphere(voxel, hnet, device):
    """
    Label each voxel of a brain MRI volume as background or one of two hemispheres.

    The volume is first normalized with ``normalize(voxel, "hemisphere")``. It is
    then passed through ``separate`` twice, once sliced along the coronal
    orientation and once along the transverse orientation. Both probability maps
    are permuted back to the axis order of the normalized volume, summed, and
    reduced to one class per voxel with argmax.

    As post-processing, each hemisphere label is grown by one binary-dilation step
    (SciPy's default structuring element) into the surrounding background. Label 1
    is grown first and label 2 second. Neither label overwrites the other, so a
    background voxel adjacent to both hemispheres ends up as label 1.

    Args:
        voxel (numpy.ndarray): Input 3D brain volume to be separated into hemispheres.
        hnet (torch.nn.Module): Trained hemisphere segmentation model.
        device (torch.device): Target device for computation (e.g., 'cuda', 'cpu').

    Returns:
        numpy.ndarray: A 3D int16 array representing the hemisphere mask:
            - 0: Background
            - 1: Left hemisphere
            - 2: Right hemisphere
    """
    normalized = normalize(voxel, "hemisphere")

    # separate() yields (slice, class, row, col); each permute maps its view back
    # to (class, axis0, axis1, axis2) of the normalized volume.
    coronal_probs = separate(normalized.transpose(1, 2, 0), hnet, device).permute(1, 3, 0, 2)
    transverse_probs = separate(normalized.transpose(2, 1, 0), hnet, device).permute(1, 3, 2, 0)

    # Fuse the two views by summing probabilities, then keep the top class.
    labels = (coronal_probs + transverse_probs).argmax(dim=0).cpu().numpy()

    # Hand cached, unused CUDA memory back to the driver.
    torch.cuda.empty_cache()

    # Stage 1: grow label 1 by one voxel. Voxels predicted as label 2 are kept,
    # so label 1 can only expand into background.
    grown = np.where(labels == 2, 2, binary_dilation(labels == 1, iterations=1)).astype("int16")

    # Stage 2: grow label 2 by one voxel. Every label-1 voxel from stage 1 is
    # kept, so label 2 only claims whatever background is still unassigned.
    return np.where(grown == 1, 1, binary_dilation(grown == 2, iterations=1) * 2).astype("int16")