#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import os
from copy import deepcopy

import numpy as np
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
    SegChannelSelectionTransform
from batchgenerators.transforms.color_transforms import GammaTransform
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor

from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
    MaskTransform, ConvertSegmentationToRegionsTransform
from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
    ApplyRandomBinaryOperatorTransform, \
    RemoveRandomConnectedComponentFromOneHotEncodingTransform

try:
    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
except ImportError:
    # Optional import: the name is bound to None when the module is unavailable.
    NonDetMultiThreadedAugmenter = None


# Default augmentation settings for 3D training. Angles are given in radians.
default_3D_augmentation_params = {
    "selected_data_channels": None,
    "selected_seg_channels": None,

    "do_elastic": True,
    "elastic_deform_alpha": (0., 900.),
    "elastic_deform_sigma": (9., 13.),
    "p_eldef": 0.2,

    "do_scaling": True,
    "scale_range": (0.85, 1.25),
    "independent_scale_factor_for_each_axis": False,
    "p_independent_scale_per_axis": 1,
    "p_scale": 0.2,

    "do_rotation": True,
    "rotation_x": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_y": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_z": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
    "rotation_p_per_axis": 1,
    "p_rot": 0.2,

    "random_crop": False,
    "random_crop_dist_to_border": None,

    "do_gamma": True,
    "gamma_retain_stats": True,
    "gamma_range": (0.7, 1.5),
    "p_gamma": 0.3,

    "do_mirror": True,
    "mirror_axes": (0, 1, 2),

    "dummy_2D": False,
    "mask_was_used_for_normalization": None,
    "border_mode_data": "constant",

    "all_segmentation_labels": None,  # used for cascade
    "move_last_seg_chanel_to_data": False,  # used for cascade
    "cascade_do_cascade_augmentations": False,  # used for cascade
    "cascade_random_binary_transform_p": 0.4,
    "cascade_random_binary_transform_p_per_label": 1,
    "cascade_random_binary_transform_size": (1, 8),
    "cascade_remove_conn_comp_p": 0.2,
    "cascade_remove_conn_comp_max_size_percent_threshold": 0.15,
    "cascade_remove_conn_comp_fill_with_other_class_p": 0.0,

    "do_additive_brightness": False,
    "additive_brightness_p_per_sample": 0.15,
    "additive_brightness_p_per_channel": 0.5,
    "additive_brightness_mu": 0.0,
    "additive_brightness_sigma": 0.1,

    # the environment variable nnUNet_n_proc_DA overrides the default of 12 workers
    "num_threads": int(os.environ.get('nnUNet_n_proc_DA', 12)),
    "num_cached_per_thread": 1,
}

# The 2D defaults start as a deep copy of the 3D defaults; only the keys below differ.
default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)

default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.)
default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.)
default_2D_augmentation_params["rotation_x"] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_y"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
default_2D_augmentation_params["rotation_z"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)

# Sometimes you have 3d data and a 3d net but cannot augment them properly in 3d due to anisotropy (which is
# currently not supported in batchgenerators). In that case you can 'cheat' and transfer your 3d data into 2d data
# and transform them back after augmentation.
default_2D_augmentation_params["dummy_2D"] = False
default_2D_augmentation_params["mirror_axes"] = (0, 1)  # this can be (0, 1, 2) if dummy_2D=True


def _largest_rotation(rot, limit):
    """Reduce a rotation spec to one angle: max |angle| for a tuple/list, then capped at ``limit``."""
    if isinstance(rot, (tuple, list)):
        rot = max(np.abs(rot))
    return min(limit, rot)


def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
    """Compute an enlarged patch size from the final patch size, the rotation ranges and the scale range.

    The result is the element-wise maximum of ``final_patch_size`` and the absolute coordinates obtained by
    rotating it about one axis at a time, divided by the smallest scale factor and truncated to int.
    In the ``__main__`` demo of this module the result is passed to the data loader as the size to load
    before augmentation (presumably so the final patch never samples outside the loaded region - confirm
    against the data loader).

    :param final_patch_size: sequence of 2 or 3 ints, the patch size wanted after augmentation
    :param rot_x: rotation about x in radians; a scalar or a tuple/list (the largest absolute value is used)
    :param rot_y: same as ``rot_x`` for y (ignored for 2D patch sizes)
    :param rot_z: same as ``rot_x`` for z (ignored for 2D patch sizes)
    :param scale_range: sequence of scale factors; only ``min(scale_range)`` is used
    :return: integer ``np.ndarray`` with one entry per entry of ``final_patch_size``
    """
    # imported here, as in the original implementation, rather than at module level
    from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d

    # angles are capped at 90 degrees
    quarter_turn = 90 / 360 * 2. * np.pi
    rot_x = _largest_rotation(rot_x, quarter_turn)
    rot_y = _largest_rotation(rot_y, quarter_turn)
    rot_z = _largest_rotation(rot_z, quarter_turn)

    coords = np.array(final_patch_size)
    final_shape = np.copy(coords)
    if len(coords) == 3:
        # rotate about one axis at a time and keep the running element-wise maximum
        for angles in ((rot_x, 0, 0), (0, rot_y, 0), (0, 0, rot_z)):
            rotated = np.abs(rotate_coords_3d(coords, *angles))
            final_shape = np.max(np.vstack((rotated, final_shape)), 0)
    elif len(coords) == 2:
        rotated = np.abs(rotate_coords_2d(coords, rot_x))
        final_shape = np.max(np.vstack((rotated, final_shape)), 0)

    # Out-of-place division: an in-place ``/=`` raises a casting error whenever final_shape is still an
    # integer array (i.e. when neither branch above ran).
    final_shape = final_shape / min(scale_range)
    return final_shape.astype(int)


def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the multi-threaded training and validation augmenters.

    Training pipeline, in order: optional data/seg channel selection, optional 3D->2D conversion (dummy 2D),
    ``SpatialTransform``, optional 2D->3D conversion, optional gamma, optional mirroring, optional masking,
    removal of label -1, optional cascade handling (previous-stage segmentation moved into the data as one-hot
    channels, plus optional cascade augmentations), renaming 'seg' to 'target', optional conversion to regions
    and conversion to tensors.

    Validation pipeline: removal of label -1, optional channel selection, optional move of the last seg channel
    to data, renaming, optional conversion to regions and conversion to tensors (no spatial/intensity
    augmentation).

    :param dataloader_train: data loader wrapped by the training ``MultiThreadedAugmenter``
    :param dataloader_val: data loader wrapped by the validation ``MultiThreadedAugmenter``
    :param patch_size: output patch size of the spatial transform; with ``dummy_2D`` the first entry is dropped
    :param params: augmentation settings, see ``default_3D_augmentation_params``. The dict is only read here.
    :param border_val_seg: forwarded to ``SpatialTransform`` as ``border_cval_seg``
    :param pin_memory: forwarded to both augmenters
    :param seeds_train: forwarded as ``seeds`` to the training augmenter
    :param seeds_val: forwarded as ``seeds`` to the validation augmenter
    :param regions: if not None, 'target' is converted with ``ConvertSegmentationToRegionsTransform``
    :return: tuple ``(batchgenerator_train, batchgenerator_val)``
    :raises AssertionError: if ``params`` still contains the outdated key 'mirror'
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    selected_data_channels = params.get("selected_data_channels")
    selected_seg_channels = params.get("selected_seg_channels")
    use_dummy_2d = bool(params.get("dummy_2D"))
    move_last_seg_to_data = bool(params.get("move_last_seg_chanel_to_data"))

    # ---------------------------------------------------------------- training
    tr_transforms = []

    if selected_data_channels is not None:
        tr_transforms.append(DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        tr_transforms.append(SegChannelSelectionTransform(selected_seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if use_dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))

    if use_dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        # NOTE(review): the two positional flags are presumably invert_image=False, per_channel=True -
        # confirm against the installed batchgenerators version.
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_last_seg_to_data:
        all_labels = params.get("all_segmentation_labels")
        tr_transforms.append(MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))

        # NOTE(review): nesting reconstructed from a whitespace-collapsed source. The cascade augmentations
        # address the last len(all_labels) data channels, i.e. the ones appended just above.
        # The previous condition contained a stray, always-true `not None`; a truthiness test is equivalent.
        if params.get("cascade_do_cascade_augmentations"):
            one_hot_channels = list(range(-len(all_labels), 0))
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=one_hot_channels,
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # Bug fix: the two settings below used to be crossed (the size threshold was passed as the
            # fill probability and vice versa). Each argument now reads the key of the same meaning.
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=one_hot_channels,
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    # For debugging, batchgenerators.dataloading.SingleThreadedAugmenter(dataloader, transforms) can be
    # substituted for the multi-threaded augmenters below.
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # -------------------------------------------------------------- validation
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if selected_data_channels is not None:
        val_transforms.append(DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        val_transforms.append(SegChannelSelectionTransform(selected_seg_channels))

    if move_last_seg_to_data:
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation gets half of the worker threads, but never fewer than one
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"), seeds=seeds_val,
                                                pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val


if __name__ == "__main__":
    # Manual smoke test: builds the augmenters for a preprocessed task.
    from nnunet.training.dataloading.dataset_loading import DataLoader3D, load_dataset
    from nnunet.paths import preprocessing_output_dir
    import pickle

    t = "Task002_Heart"
    p = os.path.join(preprocessing_output_dir, t)
    dataset = load_dataset(p, 0)

    # NOTE: pickle.load executes arbitrary code from the file; only use plans files you created yourself.
    with open(os.path.join(p, "plans.pkl"), 'rb') as f:
        plans = pickle.load(f)

    basic_patch_size = get_patch_size(np.array(plans['stage_properties'][0].patch_size),
                                      default_3D_augmentation_params['rotation_x'],
                                      default_3D_augmentation_params['rotation_y'],
                                      default_3D_augmentation_params['rotation_z'],
                                      default_3D_augmentation_params['scale_range'])

    dl = DataLoader3D(dataset, basic_patch_size, np.array(plans['stage_properties'][0].patch_size).astype(int), 1)
    tr, val = get_default_augmentation(dl, dl, np.array(plans['stage_properties'][0].patch_size).astype(int))