File size: 7,395 Bytes
599a397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai.transforms import AsDiscrete


def find_label_center_loc(x):
    """Locate the center of the non-zero region of a binary mask.

    Args:
        x (torch.Tensor): Binary mask tensor. Expected shape: [H, W, D] or [C, H, W, D].

    Returns:
        list: One entry per tensor dimension: the middle index of the sorted
            unique non-zero coordinates along that dimension, or None when the
            mask contains no non-zero elements.
    """
    nonzero_per_dim = torch.where(x != 0)
    centers = []
    for dim_indices in nonzero_per_dim:
        distinct = torch.unique(dim_indices)
        # torch.unique returns a sorted tensor, so the middle element is
        # the median coordinate along this dimension.
        centers.append(distinct[distinct.numel() // 2] if distinct.numel() > 0 else None)
    return centers


def normalize_label_to_uint8(colorize, label, n_label):
    """Convert a label map into a colorized uint8 RGB image.

    Args:
        colorize (torch.Tensor): 1x1 convolution weights mapping one-hot label
            channels to RGB. Expected shape: [3, n_label, 1, 1].
        label (torch.Tensor): Input label tensor. Expected shape: [1, H, W].
        n_label (int): Number of unique labels (one-hot channels).

    Returns:
        numpy.ndarray: Colorized image as a uint8 array of shape [H, W, 3].
    """
    with torch.no_grad():
        # One-hot encode the labels, then mix channels into RGB via a 1x1 conv.
        one_hot = AsDiscrete(to_onehot=n_label)(label).permute(1, 0, 2, 3)
        rgb = F.conv2d(one_hot, weight=colorize)
        rgb = torch.clip(rgb, 0, 1).squeeze().permute(1, 2, 0).cpu().numpy()
    # Scale [0, 1] floats to [0, 255] byte values.
    return (rgb * 255).astype(np.uint8)


def visualize_one_slice_in_3d(image, axis: int = 2, center=None, mask_bool=True, n_label=105, colorize=None):
    """Extract a 2D slice from a 3D image or label tensor for visualization.

    Args:
        image (torch.Tensor): Input 3D image or label tensor. Expected shape: [1, H, W, D].
        axis (int, optional): Spatial axis to slice along (0, 1, or 2). Defaults to 2.
        center (int, optional): Slice index; the middle slice is used when None.
        mask_bool (bool, optional): If True, treat the input as a label mask and
            colorize it via ``normalize_label_to_uint8``. Defaults to True.
        n_label (int, optional): Number of labels in the mask (mask mode only). Defaults to 105.
        colorize (torch.Tensor, optional): Colorization weights for label
            normalization. Expected shape: [3, n_label, 1, 1] if provided.

    Returns:
        numpy.ndarray: The extracted slice: uint8 in mask mode, otherwise a
            float32 array with the grayscale slice replicated across 3 channels
            on the last axis.

    Raises:
        ValueError: If ``axis`` is not 0, 1, or 2.
    """
    # Default to the middle slice along the chosen spatial axis.
    if center is None:
        center = image.shape[2:][axis] // 2
    # Index the third-to-last, second-to-last, or last dimension depending on axis.
    slice_index = {
        0: (Ellipsis, center, slice(None), slice(None)),
        1: (Ellipsis, slice(None), center, slice(None)),
        2: (Ellipsis, slice(None), slice(None), center),
    }
    if axis not in slice_index:
        raise ValueError("axis should be in [0,1,2]")
    draw_img = image[slice_index[axis]]
    if mask_bool:
        return normalize_label_to_uint8(colorize, draw_img, n_label)
    gray = draw_img.squeeze().cpu().numpy().astype(np.float32)
    # Replicate the grayscale slice into 3 identical channels.
    return np.stack((gray,) * 3, axis=-1)


def show_image(image, title="mask"):
    """Display an image in a matplotlib figure window.

    Args:
        image (numpy.ndarray): Image to display. Expected shape: [H, W] for
            grayscale or [H, W, 3] for RGB.
        title (str, optional): Title for the plot. Defaults to "mask".
    """
    plt.figure("check", (24, 12))
    axes = plt.subplot(1, 2, 1)
    axes.set_title(title)
    axes.imshow(image)
    plt.show()


def to_shape(a, shape):
    """Zero-pad a 3D array up to a desired shape.

    Padding is split evenly between both sides of each dimension, with any odd
    remainder added to the end.

    Args:
        a (numpy.ndarray): Input 3D array to be padded. Expected shape: [X, Y, Z].
        shape (tuple): Desired output shape as (x_, y_, z_).

    Returns:
        numpy.ndarray: Padded array of shape [x_, y_, z_] (or the original size
            in any dimension that already meets or exceeds the target).

    Note:
        If the input is larger than the desired shape in any dimension, that
        dimension is left unchanged — the array is never cropped.
    """
    x_, y_, z_ = shape
    x, y, z = a.shape
    # Clamp to zero: a negative pad width would make np.pad raise ValueError,
    # so dimensions already at or above the target size are left as-is.
    x_pad = max(x_ - x, 0)
    y_pad = max(y_ - y, 0)
    z_pad = max(z_ - z, 0)
    return np.pad(
        a,
        (
            (x_pad // 2, x_pad // 2 + x_pad % 2),
            (y_pad // 2, y_pad // 2 + y_pad % 2),
            (z_pad // 2, z_pad // 2 + z_pad % 2),
        ),
        mode="constant",
    )


def get_xyz_plot(image, center_loc_axis, mask_bool=True, n_label=105, colorize=None, target_class_index=0):
    """Generate a concatenated XYZ plot of 2D slices from a 3D image.

    Creates visualizations of three orthogonal slices from a 3D image and
    concatenates them side by side into a single 2D image.

    Args:
        image (torch.Tensor): Input 3D image tensor. Expected shape: [1, H, W, D].
        center_loc_axis (list): Three integers specifying the slice index for each axis.
        mask_bool (bool, optional): Whether to colorize the input as a label mask. Defaults to True.
        n_label (int, optional): Number of labels for visualization. Defaults to 105.
        colorize (torch.Tensor, optional): Colorization weights. Expected shape:
            [3, n_label, 1, 1] if provided.
        target_class_index (int, optional): Index of the target class. Not used by
            this function; kept for interface compatibility. Defaults to 0.

    Returns:
        numpy.ndarray: Concatenated 2D image of the three orthogonal slices.
            Shape: [max(H,W,D), 3*max(H,W,D), 3].

    Note:
        Each slice is zero-padded to a square of side max(H, W, D) so all three
        have matching dimensions before concatenation.
    """
    spatial_shape = list(image.shape[1:])  # spatial dims [H, W, D]
    pad_size = max(spatial_shape)
    img_list = []

    for axis in range(3):
        # Flip all spatial axes to match the expected display orientation.
        slice_img = visualize_one_slice_in_3d(
            torch.flip(image.unsqueeze(0), [-3, -2, -1]),
            axis,
            center=center_loc_axis[axis],
            mask_bool=mask_bool,
            n_label=n_label,
            colorize=colorize,
        )
        # [H, W, 3] -> [3, W, H], then zero-pad to a square canvas.
        slice_img = slice_img.transpose([2, 1, 0])
        img_list.append(to_shape(slice_img, (3, pad_size, pad_size)))

    # Concatenate once after the loop (the original rebuilt the concatenation on
    # every iteration, discarding all but the last result) and move channels last.
    return np.concatenate(img_list, axis=2).transpose([1, 2, 0])