File size: 4,906 Bytes
03642cc
 
 
 
 
 
c9d94d6
 
 
 
 
 
 
03642cc
c9d94d6
 
 
 
 
03642cc
 
c9d94d6
 
 
 
 
03642cc
 
c9d94d6
 
03642cc
 
c9d94d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03642cc
c9d94d6
 
 
 
03642cc
c9d94d6
 
 
 
 
 
03642cc
c9d94d6
03642cc
c9d94d6
 
 
03642cc
c9d94d6
 
03642cc
c9d94d6
03642cc
c9d94d6
 
03642cc
c9d94d6
 
 
 
 
 
03642cc
 
c9d94d6
 
 
03642cc
 
c9d94d6
03642cc
c9d94d6
 
03642cc
c9d94d6
03642cc
 
 
 
c9d94d6
 
 
 
03642cc
 
c9d94d6
 
 
 
03642cc
 
c9d94d6
03642cc
 
c9d94d6
 
 
 
 
03642cc
 
c9d94d6
 
03642cc
 
c9d94d6
03642cc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
import torch

from utils.functions import normalize


def parcellate(
    voxel: np.ndarray,
    model: torch.nn.Module,
    device: torch.device,
    mode: str,
    n_classes: int = 142,
) -> torch.Tensor:
    """
    Perform 2.5D neural network inference for brain parcellation along a specific anatomical plane.

    The function processes a 3D volume slice by slice using a 3-slice context window (previous,
    current, next). An additional constant-valued fourth channel encodes the orientation mode
    (Axial, Coronal, or Sagittal), allowing the network to distinguish the processing plane.

    Args:
        voxel (numpy.ndarray): 3D voxel data of shape (N, H, W), representing a single anatomical
            view. Typically (224, 224, 224), but any shape is accepted.
        model (torch.nn.Module): The trained PyTorch parcellation model.
        device (torch.device): Device for inference (CPU, CUDA, or MPS).
        mode (str): The anatomical plane used for inference. Must be one of {'Axial', 'Coronal', 'Sagittal'}.
        n_classes (int, optional): Number of output anatomical labels. Defaults to 142.

    Returns:
        torch.Tensor: A CPU tensor of shape (N, n_classes, H, W) containing softmax probabilities
        for each class at each voxel position.

    Raises:
        ValueError: If ``mode`` is not one of the three supported planes.
    """
    model.eval()
    voxel = voxel.astype(np.float32)

    # Set the constant value for the 4th channel to encode plane orientation
    if mode == "Axial":
        section_value = 1.0
    elif mode == "Coronal":
        section_value = -1.0
    elif mode == "Sagittal":
        section_value = 0.0
    else:
        raise ValueError("mode must be one of {'Axial','Coronal','Sagittal'}")

    # Generalization: derive the slice count and in-plane size from the input
    # instead of hard-coding 224 (behavior is unchanged for 224-sized volumes).
    n_slices, height, width = voxel.shape

    # Pad one slice on both ends to safely allow 3-slice context
    voxel_pad = np.pad(
        voxel,
        [(1, 1), (0, 0), (0, 0)],
        mode="constant",
        constant_values=float(voxel.min()),
    )

    # Initialize a container for the network outputs (CPU for accumulation,
    # so device memory stays bounded to a single slice at a time)
    box = torch.empty(
        (n_slices, n_classes, height, width), dtype=torch.float32, device="cpu"
    )

    # Inference loop: iterate over slices and feed triplets to the model
    with torch.inference_mode():
        for i in range(1, n_slices + 1):
            prev_ = voxel_pad[i - 1]
            curr_ = voxel_pad[i]
            next_ = voxel_pad[i + 1]

            # Build 4-channel input (3 context slices + orientation encoding)
            four_ch = np.empty((4, height, width), dtype=np.float32)
            four_ch[0] = prev_
            four_ch[1] = curr_
            four_ch[2] = next_
            four_ch[3].fill(section_value)

            inp = torch.from_numpy(four_ch).unsqueeze(0).to(device)

            # Model inference with softmax normalization over the class dim
            logits = model(inp)
            probs = torch.softmax(logits, dim=1)

            # Drop the batch dim and move to CPU explicitly before storing
            # (previously relied on implicit broadcast/cross-device assignment)
            box[i - 1] = probs.squeeze(0).cpu()

    return box


def parcellation(voxel, pnet, device):
    """
    Perform full 3D brain parcellation by aggregating predictions across multiple anatomical planes.

    The input MRI volume is intensity-normalized, then three differently oriented
    representations (coronal, sagittal, axial) are each run through the shared 2.5D
    parcellation network. Per-plane probability maps are permuted back into a common
    axis order, summed, and reduced to a discrete label map via argmax over classes.

    Args:
        voxel (numpy.ndarray): Input 3D brain volume (float array).
        pnet (torch.nn.Module): Trained parcellation network (U-Net or similar architecture).
        device (torch.device): Device on which inference will be executed (CPU or GPU).

    Returns:
        numpy.ndarray: Final 3D parcellation map (integer label image) with voxel-wise anatomical labels.
    """
    # Normalize input intensities for network inference
    voxel = normalize(voxel, "parcellation")

    # Each entry: (oriented view, plane name, permutation restoring common axis order).
    plane_specs = (
        (voxel.transpose(1, 2, 0), "Coronal", (1, 3, 0, 2)),
        (voxel, "Sagittal", (1, 0, 2, 3)),
        (voxel.transpose(2, 1, 0), "Axial", (1, 3, 2, 0)),
    )

    fused = None
    for view, plane, perm in plane_specs:
        # 2.5D inference on this plane, realigned to the shared orientation
        plane_probs = parcellate(view, pnet, device, plane).permute(*perm)
        torch.cuda.empty_cache()

        # Fuse probability maps by summation across planes
        fused = plane_probs if fused is None else fused + plane_probs
        del plane_probs

    # Convert fused probability maps to final integer labels
    return torch.argmax(fused, 0).numpy()