Spaces:
Sleeping
Sleeping
| ''' | |
| Utilities for augmentation. Credit in part to Dr. Jo Schlemper. | |
| ''' | |
| from os.path import join | |
| import torch | |
| import numpy as np | |
| import torchvision.transforms as deftfx | |
| import dataloaders.image_transforms as myit | |
| import copy | |
| from util.consts import IMG_SIZE | |
| import time | |
| import functools | |
def get_sabs_aug(input_size, use_3d=False):
    """Assemble the default (mild) augmentation settings.

    Args:
        input_size: stored unchanged under the 'patch' key.
        use_3d: stored unchanged under the '3d' key.

    Returns:
        A freshly built dict with the keys 'flip', 'affine', 'elastic',
        'patch', 'reduce_2d', '3d' and 'gamma_range'.
    """
    # flipping is switched off because medical data has fixed orientations
    flip_cfg = dict(v=False, h=False, t=False, p=0.25)
    affine_cfg = dict(rotate=5, shift=(5, 5), shear=5, scale=(0.9, 1.2))
    elastic_cfg = dict(alpha=10, sigma=5)
    return {
        'flip': flip_cfg,
        'affine': affine_cfg,
        'elastic': elastic_cfg,
        'patch': input_size,
        'reduce_2d': True,
        '3d': use_3d,
        'gamma_range': (0.5, 1.5),
    }
def get_sabs_augv3(input_size):
    """Assemble the stronger ("v3") augmentation settings.

    Args:
        input_size: stored unchanged under the 'patch' key.

    Returns:
        A freshly built dict with the keys 'flip', 'affine', 'elastic',
        'patch', 'reduce_2d' and 'gamma_range' (no '3d' entry).
    """
    flip_cfg = dict(v=False, h=False, t=False, p=0.25)
    affine_cfg = dict(rotate=30, shift=(30, 30), shear=30, scale=(0.8, 1.3))
    elastic_cfg = dict(alpha=20, sigma=5)
    return {
        'flip': flip_cfg,
        'affine': affine_cfg,
        'elastic': elastic_cfg,
        'patch': input_size,
        'reduce_2d': True,
        'gamma_range': (0.2, 1.8),
    }
def get_aug(which_aug, input_size):
    """Return the augmentation settings selected by name.

    Args:
        which_aug: 'sabs_aug' for the default settings or 'aug_v3' for the
            stronger ones.
        input_size: forwarded to the selected settings builder.

    Returns:
        The dict produced by the selected builder.

    Raises:
        NotImplementedError: if `which_aug` is not one of the known names.
    """
    if which_aug == 'sabs_aug':
        return get_sabs_aug(input_size)
    if which_aug == 'aug_v3':
        return get_sabs_augv3(input_size)
    # Name the rejected value and the valid choices so a misconfiguration is
    # diagnosable from the traceback alone.
    raise NotImplementedError(
        "Unknown augmentation {!r}; expected 'sabs_aug' or 'aug_v3'".format(
            which_aug))
| # augs = { | |
| # 'sabs_aug': get_sabs_aug, | |
| # 'aug_v3': get_sabs_augv3, # more aggressive | |
| # } | |
def get_geometric_transformer(aug, order=3):
    """Compose the geometric transforms enabled in ``aug['aug']``.

    A transform is added only when its key is present, in this order:
    'flip' -> myit.RandomFlip3D, 'affine' -> myit.RandomAffine,
    'elastic' -> myit.ElasticTransform.

    Args:
        aug: dict whose 'aug' entry holds the augmentation settings.
        order: interpolation degree handed to RandomAffine. Select order=0
            for augmenting segmentation.

    Returns:
        A torchvision ``Compose`` over the selected transforms.
    """
    settings = aug['aug']
    affine_cfg = settings.get('affine', 0)
    elastic_cfg = settings.get('elastic', {'alpha': 0, 'sigma': 0})
    elastic_alpha = elastic_cfg['alpha']
    elastic_sigma = elastic_cfg['sigma']
    flip_cfg = settings.get(
        'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})

    pipeline = []
    if 'flip' in settings:
        pipeline.append(myit.RandomFlip3D(**flip_cfg))
    if 'affine' in settings:
        affine_args = (
            affine_cfg.get('rotate'),
            affine_cfg.get('shift'),
            affine_cfg.get('shear'),
            affine_cfg.get('scale'),
            affine_cfg.get('scale_iso', True),
        )
        pipeline.append(myit.RandomAffine(*affine_args, order=order))
    if 'elastic' in settings:
        pipeline.append(myit.ElasticTransform(elastic_alpha, elastic_sigma))
    return deftfx.Compose(pipeline)
def get_geometric_transformer_3d(aug, order=3):
    """Compose the geometric transforms enabled in ``aug['aug']`` (3D variant).

    Same selection rules as the 2D builder ('flip', 'affine', 'elastic', each
    added only when its key is present), except that RandomAffine is created
    with ``use_3d=True``.

    Args:
        aug: dict whose 'aug' entry holds the augmentation settings.
        order: interpolation degree handed to RandomAffine. Select order=0
            for augmenting segmentation.

    Returns:
        A torchvision ``Compose`` over the selected transforms.
    """
    settings = aug['aug']
    affine_cfg = settings.get('affine', 0)
    elastic_cfg = settings.get('elastic', {'alpha': 0, 'sigma': 0})
    elastic_alpha = elastic_cfg['alpha']
    elastic_sigma = elastic_cfg['sigma']
    flip_cfg = settings.get(
        'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})

    pipeline = []
    if 'flip' in settings:
        pipeline.append(myit.RandomFlip3D(**flip_cfg))
    if 'affine' in settings:
        affine_args = (
            affine_cfg.get('rotate'),
            affine_cfg.get('shift'),
            affine_cfg.get('shear'),
            affine_cfg.get('scale'),
            affine_cfg.get('scale_iso', True),
        )
        pipeline.append(
            myit.RandomAffine(*affine_args, order=order, use_3d=True))
    if 'elastic' in settings:
        pipeline.append(myit.ElasticTransform(elastic_alpha, elastic_sigma))
    return deftfx.Compose(pipeline)
def gamma_transform(img, aug):
    """Apply a random gamma (power-law) intensity change to an image.

    Args:
        img: numpy array of intensities.
        aug: dict whose ``aug['aug']['gamma_range']`` is either a
            ``(low, high)`` tuple or list from which gamma is drawn
            uniformly, or ``False`` to disable the transform.

    Returns:
        The gamma-adjusted image, or `img` itself when the transform is
        disabled.

    Raises:
        ValueError: if ``gamma_range`` is neither a pair nor ``False``.
    """
    gamma_range = aug['aug']['gamma_range']
    # Lists are accepted alongside tuples: a config that has been through
    # JSON/YAML serialisation comes back with lists instead of tuples.
    if isinstance(gamma_range, (tuple, list)):
        low, high = gamma_range[0], gamma_range[1]
        gamma = np.random.rand() * (high - low) + low
        cmin = img.min()
        irange = (img.max() - cmin + 1e-5)
        # Shift to strictly positive values and normalise into (0, 1] so the
        # power is well defined, then restore the original range and offset.
        img = img - cmin + 1e-5
        img = irange * np.power(img * 1.0 / irange, gamma)
        img = img + cmin
    elif gamma_range == False:  # noqa: E712 - equality kept so 0 also disables
        pass
    else:
        raise ValueError(
            "Cannot identify gamma transform range {}".format(gamma_range))
    return img
def get_intensity_transformer(aug):
    """Bind `aug` to gamma_transform and return the resulting callable.

    The returned object is a functools.partial taking the image as its only
    remaining positional argument.
    """
    intensity_tfx = functools.partial(gamma_transform, aug=aug)
    return intensity_tfx
def transform_with_label(aug):
    """
    Build a joint image/label augmentation function.

    The geometric and intensity transformers are created once from `aug`.
    The returned callable applies the geometric transform to the image
    channels and the one-hot label together, then applies the intensity
    transform to the image channels only.
    Expected input layout:
    [H x W x C + CL]
    Where CL is the number of channels for the label. It is NOT in one-hot form
    """
    geometric_tfx = get_geometric_transformer(aug)
    intensity_tfx = get_intensity_transformer(aug)

    def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
        """
        Args
            comp: a numpy array with shape [H x W x C + c_label]
            c_label: number of channels for a compact label. Only 1 slice (H x W x 1) is supported
            c_img: number of image channels at the front of the last axis
            use_onehot: True to return the label one-hot encoded; otherwise a compact label with a trailing singleton axis is returned
            nclass: number of classes used for the one-hot expansion
            kwargs: accepted and ignored
        Returns
            (transformed image, transformed label)
        """
        stacked = copy.deepcopy(comp)
        wants_onehot = use_onehot is True
        if wants_onehot and c_label != 1:
            raise NotImplementedError(
                "Only allow compact label, also the label can only be 2d")
        assert stacked.shape[-1] == c_img + 1, "only allow single slice 2D label"

        # expand the compact label to one-hot so it can be interpolated
        # together with the image channels
        compact_label = stacked[..., c_img]
        onehot_label = np.float32(
            np.arange(nclass) == (compact_label[..., None]))
        warped = geometric_tfx(
            np.concatenate([stacked[..., :c_img], onehot_label], -1))

        # snap the interpolated one-hot channels back to 0 or 1
        warped_label = np.rint(warped[..., c_img:])
        assert warped_label.max() <= 1

        # only the image channels receive the intensity transform
        warped_img = intensity_tfx(warped[..., 0: c_img])

        if wants_onehot:
            return warped_img, warped_label
        compact_out = np.expand_dims(np.argmax(warped_label, axis=-1), -1)
        return warped_img, compact_out

    return transform
def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
    """
    Jointly augment a scan and its label volume.

    Args
        scan: a numpy array with shape [D x H x W x C]
        label: a numpy array with shape [D x H x W x 1] or [D x H x W]
        nclass: number of classes used to one-hot encode the label
        geometric_tfx: callable applied to the stacked [H x W x D x C + nclass] array
        intensity_tfx: callable applied to the image channels only
    Returns
        (t_img, t_label_h): the transformed image channels and the rounded
        one-hot label, both taken from the output of geometric_tfx
        (assumes geometric_tfx keeps the channel axis last and unchanged in size)
    """
    assert len(scan.shape) == 4, "Input scan must be 4D"
    if len(label.shape) == 3:
        label = np.expand_dims(label, -1)

    # np.concatenate already returns a new array, so no extra copy is needed
    # to protect the caller's inputs.
    comp = np.concatenate([scan, label], -1)  # [D x H x W x C + 1]

    # Everything before the trailing label channel is image data. Splitting
    # at this count (instead of a hard-coded 1) keeps multi-channel scans
    # from leaking image channels into the label.
    n_img = comp.shape[-1] - 1

    _label = comp[..., -1]
    _h_label = np.float32(np.arange(nclass) == (_label[..., None]))
    comp = np.concatenate([comp[..., :n_img], _h_label], -1)

    # change comp to be H x W x D x C + nclass
    comp = np.transpose(comp, (1, 2, 0, 3))
    comp = geometric_tfx(comp)

    # round one-hot labels to 0 or 1
    t_label_h = np.rint(comp[..., n_img:])
    assert t_label_h.max() <= 1

    # intensity transform on the image channels only
    t_img = intensity_tfx(comp[..., :n_img])
    return t_img, t_label_h
def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
    """Delegate to `transform` with the same arguments and return its result."""
    result = transform(scan, label, nclass, geometric_tfx, intensity_tfx)
    return result