| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter |
| | from batchgenerators.transforms.abstract_transforms import Compose |
| | from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \ |
| | SegChannelSelectionTransform |
| | from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ |
| | ContrastAugmentationTransform, BrightnessTransform, GammaTransform |
| | from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform |
| | from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform |
| | from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform |
| | from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor |
| | from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \ |
| | MaskTransform, ConvertSegmentationToRegionsTransform |
| | from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params |
| | from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2 |
| | from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \ |
| | ApplyRandomBinaryOperatorTransform, \ |
| | RemoveRandomConnectedComponentFromOneHotEncodingTransform |
| |
|
| | try: |
| | from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter |
| | except ImportError as ie: |
| | NonDetMultiThreadedAugmenter = None |
| |
|
| |
|
def _append_cascade_transforms(tr_transforms, params):
    """Append the cascade-only perturbations of the one-hot segmentation stored in 'data'.

    Both transforms act on the last ``len(params["all_segmentation_labels"])`` channels of 'data'
    (negative channel indices). That is where the caller expects MoveSegAsOneHotToData to have put
    the one-hot segmentation -- NOTE(review): confirm this channel layout against that transform.

    Args:
        tr_transforms: list of training transforms; extended in place.
        params: augmentation settings. Reads ``cascade_do_cascade_augmentations``,
            ``cascade_random_binary_transform_p``, ``cascade_random_binary_transform_size``,
            ``cascade_remove_conn_comp_p``, ``cascade_remove_conn_comp_fill_with_other_class_p``,
            ``cascade_remove_conn_comp_max_size_percent_threshold`` and ``all_segmentation_labels``.
    """
    if not params.get("cascade_do_cascade_augmentations"):
        return

    if params.get("cascade_random_binary_transform_p") > 0:
        tr_transforms.append(ApplyRandomBinaryOperatorTransform(
            channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
            p_per_sample=params.get("cascade_random_binary_transform_p"),
            key="data",
            strel_size=params.get("cascade_random_binary_transform_size")))

    if params.get("cascade_remove_conn_comp_p") > 0:
        tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
            channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
            key="data",
            p_per_sample=params.get("cascade_remove_conn_comp_p"),
            # Each keyword receives the setting with the matching name. Earlier code passed these
            # two settings to each other's keyword.
            fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
            dont_do_if_covers_more_than_X_percent=params.get(
                "cascade_remove_conn_comp_max_size_percent_threshold")))


def _append_target_transforms(pipeline, regions, deep_supervision_scales, soft_ds, classes):
    """Append the tail shared by the training and validation pipelines.

    The tail renames 'seg' to 'target', optionally converts the target to regions, optionally
    downsamples it for deep supervision, and finally converts 'data' and 'target' to float tensors.

    Args:
        pipeline: list of transforms; extended in place.
        regions: if not None, passed to ConvertSegmentationToRegionsTransform.
        deep_supervision_scales: if not None, the target is downsampled to these scales.
        soft_ds: select DownsampleSegForDSTransform3 (requires ``classes``) instead of
            DownsampleSegForDSTransform2.
        classes: class list forwarded to DownsampleSegForDSTransform3.

    Raises:
        AssertionError: if ``soft_ds`` is set, deep supervision is enabled and ``classes`` is None.
    """
    pipeline.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        pipeline.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            pipeline.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
        else:
            pipeline.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
                                                         output_key='target'))

    pipeline.append(NumpyToTensor(['data', 'target'], 'float'))


def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None, pin_memory=True, regions=None):
    """Build training and validation batch generators using the "insane" augmentation pipeline.

    The training pipeline applies spatial, noise, blur, brightness, contrast, low-resolution and
    gamma augmentations, then mirroring, masking and (for cascades) perturbation of the one-hot
    previous-stage segmentation. The validation pipeline only selects channels and prepares targets.

    Args:
        dataloader_train: loader for training batches. The transforms read/write the 'data' and
            'seg' keys of each batch dict.
        dataloader_val: loader for validation batches.
        patch_size: output patch size. In ``dummy_2D`` mode its first entry is dropped for the 2D
            spatial transform.
        params: dict-like augmentation settings (only read, never mutated). Must not contain the
            legacy ``mirror`` key.
        border_val_seg: constant fill value for the segmentation outside the image during the
            spatial transform.
        seeds_train: forwarded as ``seeds`` to the training MultiThreadedAugmenter.
        seeds_val: forwarded as ``seeds`` to the validation MultiThreadedAugmenter.
        order_seg: interpolation order for the segmentation in the spatial transform.
        order_data: interpolation order for the data in the spatial transform.
        deep_supervision_scales: if not None, targets are downsampled to these scales.
        soft_ds: use DownsampleSegForDSTransform3 (needs ``classes``) for deep supervision.
        classes: class list required when ``soft_ds`` is set.
        pin_memory: forwarded to both MultiThreadedAugmenters.
        regions: if not None, targets are converted to regions.

    Returns:
        Tuple ``(batchgenerator_train, batchgenerator_val)`` of MultiThreadedAugmenter instances.
        The validation augmenter uses half as many threads as training (at least one).

    Raises:
        AssertionError: if ``params`` contains ``mirror``, or if ``soft_ds`` is set with deep
            supervision but ``classes`` is None.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    # ------------------------------------------------------------------ training pipeline
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # "Dummy 2D": fold the first spatial axis away so the spatial transform works in 2D, and tell
    # the low-resolution simulation below to leave that axis untouched.
    if params.get("dummy_2D"):
        ignore_axes = (0,)
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
        p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
    ))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # Intensity augmentations with fixed probabilities/ranges specific to this pipeline.
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
                                               p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5,
                                                        order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    # NOTE(review): the two positional booleans are presumably (invert_image, per_channel) per
    # GammaTransform's signature; this call differs from the do_gamma one below only in the first flag.
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
                                                 params.get("additive_brightness_sigma"),
                                                 True, p_per_sample=params.get("additive_brightness_p_per_sample"),
                                                 p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # The legacy "mirror" key is rejected by the assert above; it is still honoured here in case
    # asserts are disabled (python -O).
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    # Replace segmentation label -1 with 0. NOTE(review): -1 is presumably the "outside" marker,
    # matching the default border_val_seg.
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        _append_cascade_transforms(tr_transforms, params)

    _append_target_transforms(tr_transforms, regions, deep_supervision_scales, soft_ds, classes)
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    # ---------------------------------------------------------------- validation pipeline
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    _append_target_transforms(val_transforms, regions, deep_supervision_scales, soft_ds, classes)
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val