# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Optional

import torch
from monai.transforms import (
    Compose,
    DivisiblePadd,
    EnsureChannelFirstd,
    EnsureTyped,
    Lambdad,
    LoadImaged,
    Orientationd,
    RandAdjustContrastd,
    RandBiasFieldd,
    RandFlipd,
    RandGibbsNoised,
    RandHistogramShiftd,
    RandRotate90d,
    RandRotated,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    RandZoomd,
    ResizeWithPadOrCropd,
    ScaleIntensityRanged,
    ScaleIntensityRangePercentilesd,
    SelectItemsd,
    Spacingd,
    SpatialPadd,
)

SUPPORT_MODALITIES = ["ct", "mri"]


def define_fixed_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
    """

    Define fixed intensity transform based on the modality.



    Args:

        modality (str): The imaging modality, either 'ct' or 'mri'.

        image_keys (List[str], optional): List of image keys. Defaults to ["image"].



    Returns:

        List: A list of intensity transforms.
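
    Example:
        A minimal sketch (the tensor shape here is an illustrative assumption, not part of the API):

        >>> import torch
        >>> xform = Compose(define_fixed_intensity_transform("ct"))
        >>> data = {"image": torch.rand(1, 32, 32, 32) * 2000.0 - 1000.0}  # fake HU values
        >>> out = xform(data)  # "image" is linearly rescaled from [-1000, 1000] HU to [0, 1]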

    """
    modality = modality.lower()  # normalize modality to lowercase before validation
    if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transforms only support {SUPPORT_MODALITIES}. Got {modality}. "
            "No intensity transform will be applied; original intensities will be used."
        )

    intensity_transforms = {
        "mri": [
            ScaleIntensityRangePercentilesd(keys=image_keys, lower=0.0, upper=99.5, b_min=0.0, b_max=1, clip=False)
        ],
        "ct": [ScaleIntensityRanged(keys=image_keys, a_min=-1000, a_max=1000, b_min=0.0, b_max=1.0, clip=True)],
    }

    if modality not in intensity_transforms:
        return []

    return intensity_transforms[modality]


def define_random_intensity_transform(modality: str, image_keys: List[str] = ["image"]) -> List:
    """

    Define random intensity transform based on the modality.



    Args:

        modality (str): The imaging modality, either 'ct' or 'mri'.

        image_keys (List[str], optional): List of image keys. Defaults to ["image"].



    Returns:

        List: A list of random intensity transforms.
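
    Example:
        A minimal sketch (the tensor shape here is an illustrative assumption, not part of the API):

        >>> import torch
        >>> aug = Compose(define_random_intensity_transform("mri"))
        >>> data = {"image": torch.rand(1, 32, 32, 32)}  # fake normalized MRI volume
        >>> out = aug(data)  # bias field, Gibbs noise, contrast, and histogram shifts may apply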

    """
    modality = modality.lower()  # Normalize modality to lowercase
    if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities."
        )

    if modality == "ct":
        return []  # CT HU intensity is stable across different datasets
    elif modality == "mri":
        return [
            RandBiasFieldd(keys=image_keys, prob=0.3, coeff_range=(0.0, 0.3)),
            RandGibbsNoised(keys=image_keys, prob=0.3, alpha=(0.5, 1.0)),
            RandAdjustContrastd(keys=image_keys, prob=0.3, gamma=(0.5, 2.0)),
            RandHistogramShiftd(keys=image_keys, prob=0.05, num_control_points=10),
        ]
    else:
        return []


def define_vae_transform(
    is_train: bool,
    modality: str,
    random_aug: bool,
    k: int = 4,
    patch_size: List[int] = [128, 128, 128],
    val_patch_size: Optional[List[int]] = None,
    output_dtype: torch.dtype = torch.float32,
    spacing_type: str = "original",
    spacing: Optional[List[float]] = None,
    image_keys: List[str] = ["image"],
    label_keys: List[str] = [],
    additional_keys: List[str] = [],
    select_channel: int = 0,
) -> Compose:
    """

    Define the MAISI VAE transform pipeline for training or validation.



    Args:

        is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping.

        modality (str): The imaging modality, either 'ct' or 'mri'.

        random_aug (bool): Whether to apply random data augmentation.

        k (int, optional): Patches should be divisible by k. Defaults to 4.

        patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training.

        val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation.

        output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.

        spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"].

        spacing (Optional[List[float]], optional): Spacing values. Defaults to None.

        image_keys (List[str], optional): List of image keys. Defaults to ["image"].

        label_keys (List[str], optional): List of label keys. Defaults to [].

        additional_keys (List[str], optional): List of additional keys. Defaults to [].

        select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.



    Returns:

        tuple: A tuple containing Composed Transform train_transforms or val_transforms depending on 'is_train'.
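
    Example:
        A minimal sketch (the file name is a hypothetical placeholder for a real CT volume):

        >>> train_xform = define_vae_transform(
        ...     is_train=True, modality="ct", random_aug=False, patch_size=[64, 64, 64]
        ... )
        >>> batch = train_xform({"image": "example_ct.nii.gz"})
        >>> tuple(batch["image"].shape)  # channel-first random patch
        (1, 64, 64, 64)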

    """
    modality = modality.lower()  # Normalize modality to lowercase
    if modality not in SUPPORT_MODALITIES:
        warnings.warn(
            f"Intensity transform only support {SUPPORT_MODALITIES}. Got {modality}. Will not do any intensity transform and will use original intensities."
        )

    if spacing_type not in ["original", "fixed", "rand_zoom"]:
        raise ValueError(f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}.")

    keys = image_keys + label_keys + additional_keys
    interp_mode = ["bilinear"] * len(image_keys) + ["nearest"] * len(label_keys)

    common_transform = [
        SelectItemsd(keys=keys, allow_missing_keys=True),
        LoadImaged(keys=keys, allow_missing_keys=True),
        EnsureChannelFirstd(keys=keys, allow_missing_keys=True),
        Orientationd(keys=keys, axcodes="RAS", allow_missing_keys=True),
    ]

    if modality == "mri":
        common_transform.append(Lambdad(keys=image_keys, func=lambda x: x[select_channel : select_channel + 1, ...]))

    common_transform.extend(define_fixed_intensity_transform(modality, image_keys=image_keys))

    if spacing_type == "fixed":
        common_transform.append(
            Spacingd(keys=image_keys + label_keys, allow_missing_keys=True, pixdim=spacing, mode=interp_mode)
        )

    random_transform = []
    if is_train and random_aug:
        random_transform.extend(define_random_intensity_transform(modality, image_keys=image_keys))
        random_transform.extend(
            [RandFlipd(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axis=axis) for axis in range(3)]
            + [
                RandRotate90d(keys=keys, allow_missing_keys=True, prob=0.5, spatial_axes=axes)
                for axes in [(0, 1), (1, 2), (0, 2)]
            ]
            + [
                RandScaleIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, factors=(0.9, 1.1)),
                RandShiftIntensityd(keys=image_keys, allow_missing_keys=True, prob=0.3, offsets=0.05),
            ]
        )

        if spacing_type == "rand_zoom":
            random_transform.extend(
                [
                    RandZoomd(
                        keys=image_keys + label_keys,
                        allow_missing_keys=True,
                        prob=0.3,
                        min_zoom=0.5,
                        max_zoom=1.5,
                        keep_size=False,
                        mode=interp_mode,
                    ),
                    RandRotated(
                        keys=image_keys + label_keys,
                        allow_missing_keys=True,
                        prob=0.3,
                        range_x=0.1,
                        range_y=0.1,
                        range_z=0.1,
                        keep_size=True,
                        mode=interp_mode,
                    ),
                ]
            )

    if is_train:
        train_crop = [
            SpatialPadd(keys=keys, spatial_size=patch_size, allow_missing_keys=True),
            RandSpatialCropd(
                keys=keys, roi_size=patch_size, allow_missing_keys=True, random_size=False, random_center=True
            ),
        ]
    else:
        val_crop = (
            [DivisiblePadd(keys=keys, allow_missing_keys=True, k=k)]
            if val_patch_size is None
            else [ResizeWithPadOrCropd(keys=keys, allow_missing_keys=True, spatial_size=val_patch_size)]
        )

    final_transform = [EnsureTyped(keys=keys, dtype=output_dtype, allow_missing_keys=True)]

    if is_train:
        train_transforms = Compose(
            common_transform + random_transform + train_crop + final_transform
            if random_aug
            else common_transform + train_crop + final_transform
        )
        return train_transforms
    else:
        val_transforms = Compose(common_transform + val_crop + final_transform)
        return val_transforms


class VAE_Transform:
    """

    A class to handle MAISI VAE transformations for different modalities.
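
    Example:
        A minimal sketch (the file name is a hypothetical placeholder for a real MRI volume):

        >>> vae_transform = VAE_Transform(is_train=False, random_aug=False, val_patch_size=[64, 64, 64])
        >>> data = {"image": "example_mri.nii.gz", "class": "mri"}
        >>> out = vae_transform(data)  # the modality is read from the "class" entry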

    """

    def __init__(
        self,
        is_train: bool,
        random_aug: bool,
        k: int = 4,
        patch_size: List[int] = [128, 128, 128],
        val_patch_size: Optional[List[int]] = None,
        output_dtype: torch.dtype = torch.float32,
        spacing_type: str = "original",
        spacing: Optional[List[float]] = None,
        image_keys: List[str] = ["image"],
        label_keys: List[str] = [],
        additional_keys: List[str] = [],
        select_channel: int = 0,
    ):
        """

        Initialize the VAE_Transform.



        Args:

            is_train (bool): Whether it's for training or not. If True, the output transform will consider random_aug, the cropping will use "patch_size" for random cropping. If False, the output transform will alwasy treat "random_aug" as False, will use "val_patch_size" for central cropping.

            random_aug (bool): Whether to apply random data augmentation for training.

            k (int, optional): Patches should be divisible by k. Defaults to 4.

            patch_size (List[int], optional): Size of the patches. Defaults to [128, 128, 128]. Will random crop patch for training.

            val_patch_size (Optional[List[int]], optional): Size of validation patches. Defaults to None. If None, will use the whole volume for validation. If given, will central crop a patch for validation.

            output_dtype (torch.dtype, optional): Output data type. Defaults to torch.float32.

            spacing_type (str, optional): Type of spacing. Defaults to "original". Choose from ["original", "fixed", "rand_zoom"].

            spacing (Optional[List[float]], optional): Spacing values. Defaults to None.

            image_keys (List[str], optional): List of image keys. Defaults to ["image"].

            label_keys (List[str], optional): List of label keys. Defaults to [].

            additional_keys (List[str], optional): List of additional keys. Defaults to [].

            select_channel (int, optional): Channel to select for multi-channel MRI. Defaults to 0.

        """
        if spacing_type not in ["original", "fixed", "rand_zoom"]:
            raise ValueError(
                f"spacing_type has to be chosen from ['original', 'fixed', 'rand_zoom']. Got {spacing_type}."
            )

        self.is_train = is_train
        self.transform_dict = {}

        for modality in ["ct", "mri"]:
            self.transform_dict[modality] = define_vae_transform(
                is_train=is_train,
                modality=modality,
                random_aug=random_aug,
                k=k,
                patch_size=patch_size,
                val_patch_size=val_patch_size,
                output_dtype=output_dtype,
                spacing_type=spacing_type,
                spacing=spacing,
                image_keys=image_keys,
                label_keys=label_keys,
                additional_keys=additional_keys,
                select_channel=select_channel,
            )

    def __call__(self, img: dict, fixed_modality: Optional[str] = None) -> dict:
        """

        Apply the appropriate transform to the input image.



        Args:

            img (dict): Input image dictionary.

            fixed_modality (Optional[str], optional): Fixed modality to use. Defaults to None.



        Returns:

            Composed Transform



        Raises:

            ValueError: If the modality is not 'ct' or 'mri'.
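
        Example:
            A minimal sketch, reusing the instance from the class-level example
            (the file name is a hypothetical placeholder):

            >>> out = vae_transform({"image": "example_ct.nii.gz"}, fixed_modality="ct")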

        """
        modality = fixed_modality or img["class"]
        modality = modality.lower()  # Normalize modality to lowercase
        if modality not in SUPPORT_MODALITIES:
            raise ValueError(
                f"VAE_Transform only supports {SUPPORT_MODALITIES}. Got {modality}; "
                "no transform pipeline exists for this modality."
            )

        transform = self.transform_dict[modality]
        return transform(img)