"""Augmentations for images, masks and point coordinates.

TODO: don't convert to numpy and back to torch.
"""

from collections.abc import Iterable, Sequence
from itertools import chain

# NOTE: `Sequence` is imported from both collections.abc and typing; the typing
# one is bound last. Both are only used in annotations in this module.
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import kornia.augmentation as K
import numpy as np
import torch
from kornia.augmentation import random_generator as rg
from kornia.augmentation.utils import _range_bound
from kornia.constants import DataKey, Resample


def default_augmenter(coords: np.ndarray):
    """Randomly flip, rotate, translate and jitter point coordinates.

    The transformations are applied about the centroid of ``coords``; the
    input array itself is not modified (a new array is returned).

    Args:
        coords: array of shape (n_points, ndim) with ndim in (2, 3).

    Returns:
        Augmented coordinates, same shape as ``coords``.
    """
    # TODO parametrize magnitude of different augmentations
    # check the rank first so 1D input trips the assertion, not an IndexError
    assert coords.ndim == 2 and coords.shape[1] in (2, 3)
    ndim = coords.shape[1]

    # first remove offset (this also creates a fresh array, so the in-place
    # operations below never touch the caller's data)
    center = coords.mean(axis=0, keepdims=True)
    coords = coords - center

    # apply random flip (multiply every axis by +1 or -1)
    coords *= 2 * np.random.randint(0, 2, (1, ndim)) - 1

    # apply rotation along the last two dimensions
    phi = np.random.uniform(0, 2 * np.pi)
    coords = _rotate(coords, phi, center=None)

    if ndim == 3:
        # rotate along the first two dimensions too
        phi2, phi3 = np.random.uniform(0, 2 * np.pi, 2)
        coords = _rotate(coords, phi2, rot_axis=(0, 1), center=None)
        coords = _rotate(coords, phi3, rot_axis=(0, 2), center=None)

    coords += center

    # translation
    trans = 128 * np.random.uniform(-1, 1, (1, ndim))
    coords += trans

    # elastic (independent gaussian jitter per coordinate)
    coords += 1.5 * np.random.normal(0, 1, coords.shape)

    return coords


def _rotate(
    coords: np.ndarray, phi: float, rot_axis=(-2, -1), center: Optional[Tuple] = None
):
    """Rotate ``coords`` by angle ``phi`` in the plane spanned by ``rot_axis``.

    Args:
        coords: array of shape (n_points, ndim) with ndim in (2, 3).
        phi: rotation angle in radians.
        rot_axis: the two coordinate axes spanning the rotation plane
            (defaults to the last two).
        center: point to rotate about; the origin if None.

    Returns:
        Rotated coordinates as a new array of the same shape.
    """
    # check the rank first so 1D input trips the assertion, not an IndexError
    assert coords.ndim == 2 and coords.shape[1] in (2, 3)
    ndim = coords.shape[1]

    if center is None:
        center = (0,) * ndim

    assert len(center) == ndim

    center = np.asarray(center)
    co, si = np.cos(phi), np.sin(phi)
    # identity everywhere except the 2x2 sub-block of the rotation plane
    Rot = np.eye(ndim)
    Rot[np.ix_(rot_axis, rot_axis)] = np.array(((co, -si), (si, co)))
    x = coords - center
    x = x @ Rot.T
    x += center
    return x


def _filter_points(
    points: np.ndarray, shape: tuple, origin: Optional[Tuple] = None
) -> np.ndarray:
    """Return indices of points inside the box ``[origin, origin + shape)``.

    Args:
        points: array of shape (n_points, ndim).
        shape: extent of the box along each of the ndim axes.
        origin: lower corner of the box; all zeros if None.

    Returns:
        1D integer array with the indices of the points inside the box.
    """
    ndim = points.shape[-1]
    if origin is None:
        origin = (0,) * ndim

    # one boolean mask per axis: lower bound inclusive, upper bound exclusive
    idx = tuple(
        np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
        for i in range(ndim)
    )
    idx = np.where(np.all(idx, axis=0))[0]
    return idx


class ConcatAffine(K.RandomAffine):
    """Concatenate multiple affine transformations without intermediates.

    Parameters are sampled from every wrapped affine and merged into a single
    parameter dict (see :meth:`merge_params`).

    Args:
        affines: the affine augmentations to merge; every one must have
            ``same_on_batch=True``.

    Raises:
        ValueError: if any affine does not have ``same_on_batch=True``.
    """

    def __init__(self, affines: Sequence[K.RandomAffine]):
        super().__init__(degrees=0)
        self._affines = affines
        if not all(a.same_on_batch for a in affines):
            raise ValueError("all affines must have same_on_batch=True")

    def merge_params(self, params: Sequence[Dict[str, torch.Tensor]]):
        """Merge params from affines.

        Translations, shears and angles are summed, centers are averaged,
        scales are multiplied and ``batch_prob`` is the element-wise maximum.
        Any other key is taken from the first parameter dict.
        """
        out = params[0].copy()

        def _torchmax(x, dim):
            return torch.max(x, dim=dim).values

        ops = {
            "translations": torch.sum,
            "center": torch.mean,
            "scale": torch.prod,
            "shear_x": torch.sum,
            "shear_y": torch.sum,
            "angle": torch.sum,
            "batch_prob": _torchmax,
        }
        for k, v in params[0].items():
            # skip empty tensors
            ps = [p[k] for p in params if len(p[k]) > 0]
            if len(ps) > 0 and k in ops:
                # reduce in float, then cast back to the dtype of the first entry
                v_new = torch.stack(ps, dim=0).float()
                v_new = ops[k](v_new, dim=0)
                v_new = v_new.to(v.dtype)
            else:
                v_new = v
            out[k] = v_new

        return out

    def forward_parameters(
        self, batch_shape: Tuple[int, ...]
    ) -> Dict[str, torch.Tensor]:
        """Sample parameters from every wrapped affine and merge them."""
        params = tuple(a.forward_parameters(batch_shape) for a in self._affines)
        return self.merge_params(params)


# custom augmentations
class RandomIntensityScaleShift(K.IntensityAugmentationBase2D):
    r"""Apply a random scale and shift to the image intensity.

    The output is ``input * scale_factor + shift_factor``, with both factors
    sampled uniformly from the given ranges.

    Args:
        scale: range the scale factor is sampled from.
        shift: range the offset is sampled from.
        clip_output: if true clip output to [0, 1].
        same_on_batch: apply the same transformation across the batch.
        p: probability of applying the transformation.
        keepdim: whether to keep the output shape the same as input (True) or
            broadcast it to the batch form (False).

    Shape:
        - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`
    """

    def __init__(
        self,
        scale: Tuple[float, float] = (0.5, 2.0),
        shift: Tuple[float, float] = (-0.1, 0.1),
        clip_output: bool = True,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        self.scale = _range_bound(
            scale, "scale", center=0, bounds=(-float("inf"), float("inf"))
        )
        self.shift = _range_bound(
            shift, "shift", center=0, bounds=(-float("inf"), float("inf"))
        )
        self._param_generator = rg.PlainUniformGenerator(
            (self.scale, "scale_factor", None, None),
            (self.shift, "shift_factor", None, None),
        )

        self.clip_output = clip_output

    def apply_transform(
        self,
        input: torch.Tensor,
        params: Dict[str, torch.Tensor],
        flags: Dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Scale and shift ``input`` with the sampled per-sample factors."""
        scale_factor = params["scale_factor"].to(input)
        shift_factor = params["shift_factor"].to(input)
        # reshape to (B, 1, 1, 1) so the factors broadcast over C, H, W
        scale_factor = scale_factor.view(len(scale_factor), 1, 1, 1)
        shift_factor = shift_factor.view(len(scale_factor), 1, 1, 1)
        img_adjust = input * scale_factor + shift_factor
        if self.clip_output:
            img_adjust = img_adjust.clamp(min=0.0, max=1.0)
        return img_adjust


class RandomTemporalAffine(K.RandomAffine):
    r"""Apply a random 2D affine transformation to a batch of images while
    varying the transformation across the time dimension from 0 to 1.

    One set of parameters is sampled for the whole batch
    (``same_on_batch=True`` is forced) and then blended linearly along the
    batch axis: the first element gets the identity, the last element the full
    sampled transformation.

    Same args/kwargs as K.RandomAffine
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, same_on_batch=True, **kwargs)

    def forward_parameters(
        self, batch_shape: Tuple[int, ...]
    ) -> Dict[str, torch.Tensor]:
        """Sample affine parameters and ramp them along the batch axis."""
        params = super().forward_parameters(batch_shape)
        factor = torch.linspace(0, 1, batch_shape[0]).to(params["translations"])

        def _factor_like(v: torch.Tensor) -> torch.Tensor:
            # reshape to (B, 1, ..., 1) so the ramp broadcasts against v
            return factor.view(-1, *([1] * (v.dim() - 1)))

        # additive parameters: interpolate between 0 and the sampled value
        for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
            v = params[key]
            if len(v) > 0:
                params[key] = v * _factor_like(v)

        # multiplicative parameters: interpolate between 1 and the sampled value
        for key in ["scale"]:
            v = params[key]
            if len(v) > 0:
                params[key] = 1 + (v - 1) * _factor_like(v)

        return params


class BasicPipeline:
    """transforms img, mask, and points.

    Only supports 2D transformations for now (any 3D object will preserve its
    z coordinates/dimensions)

    Args:
        augs: kornia augmentations, applied in order.
        filter_points: if True, drop points that end up outside the image.
    """

    def __init__(self, augs: tuple, filter_points: bool = True):
        self.data_keys = ("input", "mask", "keypoints")
        self.pipeline = K.AugmentationSequential(
            *augs,
            # disable align_corners to not trigger lots of warnings from kornia
            extra_args={
                DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": False}
            },
            data_keys=self.data_keys,
        )
        self.filter_points = filter_points

    def __call__(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        points: np.ndarray,
        timepoints: np.ndarray,
    ):
        """Augment an image stack, its mask and the associated points.

        Args:
            img: array with one leading (time) axis plus 2 or 3 spatial axes.
            mask: array of the same shape as ``img``.
            points: array of shape (n_points, ndim), coordinates ordered like
                the spatial axes of ``img``.
            timepoints: 1D array with the index along the leading axis of
                ``img`` that each point belongs to.

        Returns:
            ``((img, mask, points), idx)`` where ``idx`` holds the indices of
            the input points that are kept in the returned ``points``.
        """
        ndim = img.ndim - 1
        assert (
            ndim in (2, 3)
            and points.ndim == 2
            and points.shape[-1] == ndim
            and timepoints.ndim == 1
            and img.shape == mask.shape
        )

        x = torch.from_numpy(img).float()
        y = torch.from_numpy(mask.astype(np.int64)).float()

        # if 2D add dummy channel
        if ndim == 2:
            x = x.unsqueeze(1)
            y = y.unsqueeze(1)
            p = points[..., [1, 0]]
        # if 3D we use z as channel (i.e. fix augs across z)
        elif ndim == 3:
            p = points[..., [2, 1]]

        # the column selection above flipped the order, as kornia expects xy
        # and not yx
        p = torch.from_numpy(p).unsqueeze(0).float()
        # add batch by duplicating to make kornia happy
        p = p.expand(len(x), -1, -1)
        # create a mask to know which timepoint the points belong to
        ts = torch.from_numpy(timepoints).long()
        n_points = p.shape[1]
        if n_points > 0:
            x, y, p = self.pipeline(x, y, p)
        else:
            # dummy keypoints
            x, y = self.pipeline(x, y, torch.zeros((len(x), 1, 2)))[:2]

        # remove batch: for every point pick the copy transformed with the
        # parameters of its own timepoint
        p = p[ts, torch.arange(n_points)]
        # flip back
        p = p[..., [1, 0]]

        # remove channel
        if ndim == 2:
            x = x.squeeze(1)
            y = y.squeeze(1)

        x = x.numpy()
        y = y.numpy().astype(np.uint16)
        p = p.numpy()
        # add back z coordinates
        if ndim == 3:
            p = np.concatenate([points[..., 0:1], p], axis=-1)
        # remove points outside of img/mask
        if self.filter_points:
            idx = _filter_points(p, shape=x.shape[-ndim:])
        else:
            idx = np.arange(len(p), dtype=int)

        p = p[idx]
        return (x, y, p), idx


class RandomCrop:
    """Randomly crop an image stack, its mask and the associated points.

    The crop size is sampled per call between per-axis lower and upper bounds,
    and all random numbers are drawn from the instance's own random state.
    """

    def __init__(
        self,
        crop_size: Optional[Union[int, Tuple[int]]] = None,
        ndim: int = 2,
        ensure_inside_points: bool = False,
        use_padding: bool = True,
        padding_mode="constant",
    ) -> None:
        """crop_size: tuple of int
        can be tuple of length 1 (all dimensions)
                     of length ndim (y,x,...)
                     of length 2*ndim (y1,y2, x1,x2, ...).

        ndim: number of spatial dimensions.
        ensure_inside_points: if True, place the crop such that a randomly
            chosen point lies inside it.
        use_padding: if True, crops reaching beyond the image are padded;
            otherwise corner and size are clipped to the image.
        padding_mode: ``mode`` argument passed to ``np.pad``.

        Raises ValueError if crop_size is neither an int nor an iterable of a
        supported length (note that the default ``None`` is not supported).
        """
        if isinstance(crop_size, int):
            crop_size = (crop_size,) * 2 * ndim
        elif isinstance(crop_size, Iterable):
            pass
        else:
            raise ValueError(f"{crop_size} has to be int or tuple of int")

        if len(crop_size) == 1:
            crop_size = (crop_size[0],) * 2 * ndim
        elif len(crop_size) == ndim:
            # same lower and upper bound for every axis
            crop_size = tuple(chain(*tuple((c, c) for c in crop_size)))
        elif len(crop_size) == 2 * ndim:
            pass
        else:
            raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")

        crop_size = np.array(crop_size)
        self._ndim = ndim
        # (lower bounds, upper bounds) of the crop size for every axis
        self._crop_bounds = crop_size[::2], crop_size[1::2]
        self._use_padding = use_padding
        self._ensure_inside_points = ensure_inside_points
        self._rng = np.random.RandomState()
        self._padding_mode = padding_mode

    def crop_img(self, img: np.ndarray, corner: np.ndarray, crop_size: np.ndarray):
        """Crop ``img`` to ``[corner, corner + crop_size)`` on the spatial axes.

        Regions of the crop outside the image are filled by ``np.pad`` with
        the configured padding mode; the leading (time) axis is kept whole.

        Raises:
            ValueError: if ``img`` does not have 1 + ndim dimensions.
        """
        if not img.ndim == self._ndim + 1:
            raise ValueError(
                f"img has to be 1 (time) + {self._ndim} spatial dimensions"
            )

        # amount the crop sticks out of the image on either side
        pad_left = np.maximum(0, -corner)
        pad_right = np.maximum(
            0, corner + crop_size - np.array(img.shape[-self._ndim :])
        )

        img = np.pad(
            img,
            ((0, 0), *tuple(np.stack((pad_left, pad_right)).T)),
            mode=self._padding_mode,
        )
        # after padding, the corner is shifted by the left padding
        slices = (
            slice(None),
            *tuple(slice(c, c + s) for c, s in zip(corner + pad_left, crop_size)),
        )
        return img[slices]

    def crop_points(
        self, points: np.ndarray, corner: np.ndarray, crop_size: np.ndarray
    ):
        """Return the points inside the crop (relative to ``corner``) and their indices."""
        idx = _filter_points(points, shape=crop_size, origin=corner)
        return points[idx] - corner, idx

    def __call__(self, img: np.ndarray, mask: np.ndarray, points: np.ndarray):
        """Crop ``img``, ``mask`` and ``points`` with a randomly placed window.

        Returns:
            ``((img, mask, points), idx)`` where ``idx`` holds the indices of
            the input points that lie inside the crop. Points are cast to int.
            If ``ensure_inside_points`` is set but no points are given, the
            inputs are returned uncropped together with an empty ``idx``.
        """
        assert (
            img.ndim == self._ndim + 1
            and points.ndim == 2
            and points.shape[-1] == self._ndim
            and img.shape == mask.shape
        )

        points = points.astype(int)

        crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)

        if self._ensure_inside_points:
            if len(points) == 0:
                print("No points given, cannot ensure inside points")
                return (img, mask, points), np.zeros((0,), int)

            # sample point and corner relative to it; use the instance rng
            # like every other draw in this class
            _idx = self._rng.randint(len(points))
            low = crop_size // 4
            # keep low < high even for a crop extent of 1 (both bounds are 0
            # then); for extents >= 2 this is the unchanged 3 * crop_size // 4
            high = np.maximum(3 * crop_size // 4, low + 1)
            corner = points[_idx] - crop_size + 1 + self._rng.randint(low, high)
        else:
            corner = self._rng.randint(
                0, np.maximum(1, np.array(img.shape[-self._ndim :]) - crop_size)
            )

        if not self._use_padding:
            # clip the window to the image instead of padding
            corner = np.maximum(0, corner)
            crop_size = np.minimum(
                crop_size, np.array(img.shape[-self._ndim :]) - corner
            )

        img = self.crop_img(img, corner, crop_size)
        mask = self.crop_img(mask, corner, crop_size)
        points, idx = self.crop_points(points, corner, crop_size)

        return (img, mask, points), idx


class AugmentationPipeline(BasicPipeline):
    """transforms img, mask, and points.

    Args:
        p: probability used for most of the individual augmentations.
        filter_points: if True, drop points that end up outside the image.
        level: selects one of the predefined augmentation sets (1-3), or 4
            for a rotation-only debug set.

    Raises:
        ValueError: if ``level`` is not supported.
    """

    def __init__(self, p=0.5, filter_points=True, level=1):
        if level == 1:
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                K.RandomAffine(
                    degrees=180,
                    shear=(-10, 10, -10, 10),  # x_min, x_max, y_min, y_max
                    translate=(0.05, 0.05),
                    scale=(0.8, 1.2),  # isotropic
                    p=p,
                    same_on_batch=True,
                ),
                K.RandomBrightness(
                    (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 2:
            # Crafted for DeepCell crop size 256
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                K.RandomAffine(
                    degrees=180,
                    shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                    translate=(0.03, 0.03),
                    scale=(0.8, 1.2),  # isotropic
                    p=p,
                    same_on_batch=True,
                ),
                # Anisotropic scaling
                K.RandomAffine(
                    degrees=0,
                    scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                    p=p,
                    same_on_batch=True,
                ),
                # Independent augmentations for each image in window
                K.RandomAffine(
                    degrees=3,
                    shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                    translate=(0.04, 0.04),
                    p=p,
                    same_on_batch=False,
                ),
                # not implemented for points in kornia 0.7.0
                # K.RandomElasticTransform(alpha=50, sigma=5, p=p, same_on_batch=False),
                # Intensity-based augmentations
                K.RandomBrightness(
                    (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 3:
            # Crafted for DeepCell crop size 256
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                ConcatAffine([
                    K.RandomAffine(
                        degrees=180,
                        shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                        translate=(0.03, 0.03),
                        scale=(0.8, 1.2),  # isotropic
                        p=p,
                        same_on_batch=True,
                    ),
                    # Anisotropic scaling
                    K.RandomAffine(
                        degrees=0,
                        scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                        p=p,
                        same_on_batch=True,
                    ),
                ]),
                # same_on_batch=True is forced by RandomTemporalAffine itself
                RandomTemporalAffine(
                    degrees=10,
                    translate=(0.05, 0.05),
                    p=p,
                ),
                # Independent augmentations for each image in window
                K.RandomAffine(
                    degrees=2,
                    shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                    translate=(0.01, 0.01),
                    p=0.5 * p,
                    same_on_batch=False,
                ),
                # Intensity-based augmentations
                RandomIntensityScaleShift(
                    (0.5, 2.0), (-0.1, 0.1), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 4:
            # debug
            augs = [
                K.RandomAffine(
                    degrees=30,
                    shear=(-0, 0, -0, 0),  # x_min, x_max, y_min, y_max
                    translate=(0.0, 0.0),
                    p=1,
                    same_on_batch=True,
                ),
            ]
        else:
            raise ValueError(f"level {level} not supported")

        super().__init__(augs, filter_points)