# Provenance: Hugging Face upload "Upload 78 files" (commit 8f72b1f, verified) by Shengxiao0709.
"""#TODO: dont convert to numpy and back to torch."""
from collections.abc import Iterable, Sequence
from itertools import chain
from typing import Any
import kornia.augmentation as K
import numpy as np
import torch
from kornia.augmentation import random_generator as rg
from kornia.augmentation.utils import _range_bound
from kornia.constants import DataKey, Resample
from typing import Optional, Tuple, Sequence, Dict, Union
def default_augmenter(coords: np.ndarray):
    """Apply a random flip, rotation, translation and jitter to point coords.

    ``coords`` is an (N, ndim) array with ndim in (2, 3). Returns a new
    array; the input is not modified.
    """
    # TODO parametrize magnitude of different augmentations
    n_dims = coords.shape[1]
    assert coords.ndim == 2 and n_dims in (2, 3)
    # work relative to the centroid so flips/rotations pivot around it
    centroid = coords.mean(axis=0, keepdims=True)
    shifted = coords - centroid
    # random per-axis sign flip
    flips = 2 * np.random.randint(0, 2, (1, n_dims)) - 1
    shifted = shifted * flips
    # random rotation in the plane of the last two axes
    angle = np.random.uniform(0, 2 * np.pi)
    shifted = _rotate(shifted, angle, center=None)
    if n_dims == 3:
        # two further rotations to cover the remaining axis pairs
        angle_01, angle_02 = np.random.uniform(0, 2 * np.pi, 2)
        shifted = _rotate(shifted, angle_01, rot_axis=(0, 1), center=None)
        shifted = _rotate(shifted, angle_02, rot_axis=(0, 2), center=None)
    shifted = shifted + centroid
    # random global translation, up to 128 units per axis
    shifted = shifted + 128 * np.random.uniform(-1, 1, (1, n_dims))
    # small per-point gaussian jitter ("elastic")
    shifted = shifted + 1.5 * np.random.normal(0, 1, shifted.shape)
    return shifted
def _rotate(
coords: np.ndarray, phi: float, rot_axis=(-2, -1), center: Optional[Tuple] = None
):
"""Rotation along the last two dimensions of coords[..,:-2:]."""
ndim = coords.shape[1]
assert coords.ndim == 2 and ndim in (2, 3)
if center is None:
center = (0,) * ndim
assert len(center) == ndim
center = np.asarray(center)
co, si = np.cos(phi), np.sin(phi)
Rot = np.eye(ndim)
Rot[np.ix_(rot_axis, rot_axis)] = np.array(((co, -si), (si, co)))
x = coords - center
x = x @ Rot.T
x += center
return x
def _filter_points(
points: np.ndarray, shape: tuple, origin: Optional[Tuple] = None
) -> np.ndarray:
"""Returns indices of points that are inside the shape extent and given origin."""
ndim = points.shape[-1]
if origin is None:
origin = (0,) * ndim
idx = tuple(
np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
for i in range(ndim)
)
idx = np.where(np.all(idx, axis=0))[0]
return idx
class ConcatAffine(K.RandomAffine):
    """Concatenate multiple affine transformations without intermediates.

    Instead of applying each child ``K.RandomAffine`` separately (one image
    resampling per transform), the sampled parameters of all children are
    merged into a single parameter set so only one warp is performed.
    """

    def __init__(self, affines: Sequence[K.RandomAffine]):
        # degrees=0: the base class's own sampling is irrelevant here; the
        # effective parameters come from the children via forward_parameters.
        super().__init__(degrees=0)
        self._affines = affines
        # merging across transforms is only well-defined when each child
        # samples one parameter set shared by the whole batch
        if not all([a.same_on_batch for a in affines]):
            raise ValueError("all affines must have same_on_batch=True")

    def merge_params(self, params: Sequence[Dict[str, torch.Tensor]]):
        """Merge params from affines."""
        out = params[0].copy()

        def _torchmax(x, dim):
            # torch.max with dim returns (values, indices); keep the values
            return torch.max(x, dim=dim).values

        # reduction used per parameter when composing the affines:
        # translations/shears/angles add, scales multiply, centers average,
        # batch_prob takes the max (active if any child is active)
        ops = {
            "translations": torch.sum,
            "center": torch.mean,
            "scale": torch.prod,
            "shear_x": torch.sum,
            "shear_y": torch.sum,
            "angle": torch.sum,
            "batch_prob": _torchmax,
        }
        for k, v in params[0].items():
            # consider only non-empty entries; keys without a known
            # reduction (or with all entries empty) keep the first
            # affine's value unchanged
            ps = [p[k] for p in params if len(p[k]) > 0]
            if len(ps) > 0 and k in ops:
                # reduce in float, then cast back to the entry's dtype
                v_new = torch.stack(ps, dim=0).float()
                v_new = ops[k](v_new, dim=0)
                v_new = v_new.to(v.dtype)
            else:
                v_new = v
            out[k] = v_new
        return out

    def forward_parameters(
        self, batch_shape: Tuple[int, ...]
    ) -> Dict[str, torch.Tensor]:
        """Sample parameters from every child affine and merge them."""
        params = tuple(a.forward_parameters(batch_shape) for a in self._affines)
        # print(params)
        return self.merge_params(params)
# custom augmentations
class RandomIntensityScaleShift(K.IntensityAugmentationBase2D):
    r"""Apply a random scale and shift to the image intensity.

    The transform is ``out = img * scale + shift`` with per-sample factors
    drawn uniformly from the given ranges.

    Args:
        scale: range the multiplicative factor is sampled from.
        shift: range the additive offset is sampled from.
        clip_output: if true clip output to [0, 1].
        same_on_batch: apply the same transformation across the batch.
        p: probability of applying the transformation.
        keepdim: whether to keep the output shape the same as input (True)
            or broadcast it to the batch form (False).

    Shape:
        - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
        - Output: :math:`(B, C, H, W)`
    """

    def __init__(
        self,
        scale: Tuple[float, float] = (0.5, 2.0),
        shift: Tuple[float, float] = (-0.1, 0.1),
        clip_output: bool = True,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # unbounded ranges: scale/shift may legitimately be negative or > 1
        self.scale = _range_bound(
            scale, "scale", center=0, bounds=(-float("inf"), float("inf"))
        )
        self.shift = _range_bound(
            shift, "shift", center=0, bounds=(-float("inf"), float("inf"))
        )
        self._param_generator = rg.PlainUniformGenerator(
            (self.scale, "scale_factor", None, None),
            (self.shift, "shift_factor", None, None),
        )
        self.clip_output = clip_output

    def apply_transform(
        self,
        input: torch.Tensor,
        params: Dict[str, torch.Tensor],
        flags: Dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        scale_factor = params["scale_factor"].to(input)
        shift_factor = params["shift_factor"].to(input)
        # reshape per-sample factors to broadcast over (B, C, H, W).
        # Fixed: shift_factor previously reused len(scale_factor) for its
        # view — correct only by coincidence of equal lengths.
        scale_factor = scale_factor.view(len(scale_factor), 1, 1, 1)
        shift_factor = shift_factor.view(len(shift_factor), 1, 1, 1)
        img_adjust = input * scale_factor + shift_factor
        if self.clip_output:
            img_adjust = img_adjust.clamp(min=0.0, max=1.0)
        return img_adjust
class RandomTemporalAffine(K.RandomAffine):
    r"""Apply a random 2D affine transformation to a batch of images while
    varying the transformation across the time dimension from 0 to 1.

    The batch dimension is interpreted as time: the first frame receives
    (approximately) the identity transform and the last frame the fully
    sampled transform, with the parameters interpolated linearly in between.

    Same args/kwargs as K.RandomAffine; ``same_on_batch`` is forced to True
    so a single parameter set is sampled and then ramped over time.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, same_on_batch=True, **kwargs)

    def forward_parameters(
        self, batch_shape: Tuple[int, ...]
    ) -> Dict[str, torch.Tensor]:
        params = super().forward_parameters(batch_shape)
        # linear 0..1 ramp over the batch (= time) dimension
        factor = torch.linspace(0, 1, batch_shape[0]).to(params["translations"])
        # additive parameters ramp from 0 (identity) to the sampled value.
        # NOTE(review): "center" is also ramped towards 0, which moves the
        # rotation pivot over time — confirm this is intended.
        for key in ["translations", "center", "angle", "shear_x", "shear_y"]:
            v = params[key]
            if len(v) > 0:
                params[key] = v * factor.view(*((-1,) + (1,) * len(v.shape[1:])))
        # multiplicative parameters ramp from 1 (identity) to the sampled value
        for key in ["scale"]:
            v = params[key]
            if len(v) > 0:
                params[key] = 1 + (v - 1) * factor.view(
                    *((-1,) + (1,) * len(v.shape[1:]))
                )
        return params
class BasicPipeline:
    """transforms img, mask, and points.

    Only supports 2D transformations for now (any 3D object will preserve
    its z coordinates/dimensions)
    """

    def __init__(self, augs: tuple, filter_points: bool = True):
        # augs: kornia augmentations applied jointly to image, mask and
        # keypoints; filter_points: drop points mapped outside the image
        self.data_keys = ("input", "mask", "keypoints")
        self.pipeline = K.AugmentationSequential(
            *augs,
            # disable align_corners to not trigger lots of warnings from kornia
            extra_args={
                DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": False}
            },
            data_keys=self.data_keys,
        )
        self.filter_points = filter_points

    def __call__(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        points: np.ndarray,
        timepoints: np.ndarray,
    ):
        """Jointly augment ``img``, ``mask`` and ``points``.

        img/mask: (T, [Z,] Y, X) arrays with identical shape;
        points: (N, ndim) in (y, x) or (z, y, x) order;
        timepoints: (N,) frame index of each point.
        Returns ``(img, mask, points), idx`` where ``idx`` indexes the
        input points that survived (all of them when not filtering).
        """
        ndim = img.ndim - 1
        assert (
            ndim in (2, 3)
            and points.ndim == 2
            and points.shape[-1] == ndim
            and timepoints.ndim == 1
            and img.shape == mask.shape
        )
        x = torch.from_numpy(img).float()
        # labels go through the pipeline as float; cast back to int below
        y = torch.from_numpy(mask.astype(np.int64)).float()
        # if 2D add dummy channel
        if ndim == 2:
            x = x.unsqueeze(1)
            y = y.unsqueeze(1)
            p = points[..., [1, 0]]
        # if 3D we use z as channel (i.e. fix augs across z)
        elif ndim == 3:
            p = points[..., [2, 1]]
        # flip as kornia expects xy and not yx
        p = torch.from_numpy(p).unsqueeze(0).float()
        # add batch by duplicating to make kornia happy
        p = p.expand(len(x), -1, -1)
        # create a mask to know which timepoint the points belong to
        ts = torch.from_numpy(timepoints).long()
        n_points = p.shape[1]
        if n_points > 0:
            x, y, p = self.pipeline(x, y, p)
        else:
            # dummy keypoints (pipeline still needs a keypoints tensor)
            x, y = self.pipeline(x, y, torch.zeros((len(x), 1, 2)))[:2]
        # remove batch: pick each point from the batch copy of its own frame
        p = p[ts, torch.arange(n_points)]
        # flip back
        p = p[..., [1, 0]]
        # remove channel
        if ndim == 2:
            x = x.squeeze(1)
            y = y.squeeze(1)
        x = x.numpy()
        y = y.numpy().astype(np.uint16)
        # p = p.squeeze(0).numpy()
        p = p.numpy()
        # add back z coordinates (z was untouched by the 2D pipeline)
        if ndim == 3:
            p = np.concatenate([points[..., 0:1], p], axis=-1)
        ts = ts.numpy()
        # remove points outside of img/mask
        if self.filter_points:
            idx = _filter_points(p, shape=x.shape[-ndim:])
        else:
            idx = np.arange(len(p), dtype=int)
        p = p[idx]
        return (x, y, p), idx
class RandomCrop:
    """Randomly crop an (time + spatial) image/mask pair and its points."""

    def __init__(
        self,
        crop_size: Optional[Union[int, Tuple[int, ...]]] = None,
        ndim: int = 2,
        ensure_inside_points: bool = False,
        use_padding: bool = True,
        padding_mode="constant",
    ) -> None:
        """crop_size: tuple of int
        can be tuple of length 1 (all dimensions)
        of length ndim (y,x,...)
        of length 2*ndim (y1,y2, x1,x2, ...) giving per-dimension
        (min, max) bounds for a randomly sampled crop size.

        ndim: number of spatial dimensions (images carry a leading time axis).
        ensure_inside_points: anchor the crop near a randomly chosen point.
        use_padding: allow crops extending beyond the image by padding;
        otherwise the crop is clamped/shrunk to fit inside the image.
        padding_mode: mode passed to ``np.pad``.
        """
        # normalize crop_size to length 2*ndim: (min1, max1, min2, max2, ...)
        if isinstance(crop_size, int):
            crop_size = (crop_size,) * 2 * ndim
        elif isinstance(crop_size, Iterable):
            pass
        else:
            raise ValueError(f"{crop_size} has to be int or tuple of int")
        if len(crop_size) == 1:
            crop_size = (crop_size[0],) * 2 * ndim
        elif len(crop_size) == ndim:
            crop_size = tuple(chain(*tuple((c, c) for c in crop_size)))
        elif len(crop_size) == 2 * ndim:
            pass
        else:
            raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
        crop_size = np.array(crop_size)
        self._ndim = ndim
        # per-dimension (low, high) bounds for the sampled crop size
        self._crop_bounds = crop_size[::2], crop_size[1::2]
        self._use_padding = use_padding
        self._ensure_inside_points = ensure_inside_points
        self._rng = np.random.RandomState()
        self._padding_mode = padding_mode

    def crop_img(self, img: np.ndarray, corner: np.ndarray, crop_size: np.ndarray):
        """Crop ``img`` at ``corner`` with extent ``crop_size`` (spatial dims
        only), padding wherever the crop extends past the image."""
        if not img.ndim == self._ndim + 1:
            raise ValueError(
                f"img has to be 1 (time) + {self._ndim} spatial dimensions"
            )
        # pad only as much as the crop sticks out on either side
        pad_left = np.maximum(0, -corner)
        pad_right = np.maximum(
            0, corner + crop_size - np.array(img.shape[-self._ndim :])
        )
        img = np.pad(
            img,
            ((0, 0), *tuple(np.stack((pad_left, pad_right)).T)),
            mode=self._padding_mode,
        )
        # corner is shifted by the padding added on the left
        slices = (
            slice(None),
            *tuple(slice(c, c + s) for c, s in zip(corner + pad_left, crop_size)),
        )
        return img[slices]

    def crop_points(
        self, points: np.ndarray, corner: np.ndarray, crop_size: np.ndarray
    ):
        """Return points inside the crop (shifted to crop coordinates) and
        their indices in the input array."""
        idx = _filter_points(points, shape=crop_size, origin=corner)
        return points[idx] - corner, idx

    def __call__(self, img: np.ndarray, mask: np.ndarray, points: np.ndarray):
        """Crop ``img``/``mask`` and the points that fall inside the crop.

        Returns ``(img, mask, points), idx`` where ``idx`` indexes the input
        points that survived the crop.
        """
        assert (
            img.ndim == self._ndim + 1
            and points.ndim == 2
            and points.shape[-1] == self._ndim
            and img.shape == mask.shape
        )
        points = points.astype(int)
        # sample a crop size within the configured per-dimension bounds
        crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
        if self._ensure_inside_points:
            if len(points) == 0:
                print("No points given, cannot ensure inside points")
                # NOTE(review): returns all points with an empty idx — verify
                # callers do not rely on idx <-> returned-points correspondence
                return (img, mask, points), np.zeros((0,), int)
            # sample point and corner relative to it, so the chosen point
            # lands roughly in the central half of the crop.
            # Fixed: use self._rng (not the global np.random) so seeding
            # this instance's RNG makes the crop fully reproducible.
            _idx = self._rng.randint(len(points))
            corner = (
                points[_idx]
                - crop_size
                + 1
                + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
            )
        else:
            corner = self._rng.randint(
                0, np.maximum(1, np.array(img.shape[-self._ndim :]) - crop_size)
            )
        if not self._use_padding:
            # clamp the crop into the image instead of padding
            corner = np.maximum(0, corner)
            crop_size = np.minimum(
                crop_size, np.array(img.shape[-self._ndim :]) - corner
            )
        img = self.crop_img(img, corner, crop_size)
        mask = self.crop_img(mask, corner, crop_size)
        points, idx = self.crop_points(points, corner, crop_size)
        return (img, mask, points), idx
class AugmentationPipeline(BasicPipeline):
    """transforms img, mask, and points.

    ``level`` selects one of several predefined augmentation recipes of
    increasing complexity (level 4 is a minimal debug setting).
    """

    def __init__(self, p=0.5, filter_points=True, level=1):
        # p: probability used for most of the stochastic augmentations
        if level == 1:
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                K.RandomAffine(
                    degrees=180,
                    shear=(-10, 10, -10, 10),  # x_min, x_max, y_min, y_max
                    translate=(0.05, 0.05),
                    scale=(0.8, 1.2),  # isotropic
                    p=p,
                    same_on_batch=True,
                ),
                K.RandomBrightness(
                    (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 2:
            # Crafted for DeepCell crop size 256
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                K.RandomAffine(
                    degrees=180,
                    shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                    translate=(0.03, 0.03),
                    scale=(0.8, 1.2),  # isotropic
                    p=p,
                    same_on_batch=True,
                ),
                # Anisotropic scaling
                K.RandomAffine(
                    degrees=0,
                    scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                    p=p,
                    same_on_batch=True,
                ),
                # Independent augmentations for each image in window
                K.RandomAffine(
                    degrees=3,
                    shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                    translate=(0.04, 0.04),
                    p=p,
                    same_on_batch=False,
                ),
                # not implemented for points in kornia 0.7.0
                # K.RandomElasticTransform(alpha=50, sigma=5, p=p, same_on_batch=False),
                # Intensity-based augmentations
                K.RandomBrightness(
                    (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 3:
            # Crafted for DeepCell crop size 256
            augs = [
                # Augmentations for all images in a window
                K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
                K.RandomVerticalFlip(p=0.5, same_on_batch=True),
                # merged into one warp to avoid double resampling
                ConcatAffine([
                    K.RandomAffine(
                        degrees=180,
                        shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                        translate=(0.03, 0.03),
                        scale=(0.8, 1.2),  # isotropic
                        p=p,
                        same_on_batch=True,
                    ),
                    # Anisotropic scaling
                    K.RandomAffine(
                        degrees=0,
                        scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                        p=p,
                        same_on_batch=True,
                    ),
                ]),
                # affine ramped linearly over the time dimension
                RandomTemporalAffine(
                    degrees=10,
                    translate=(0.05, 0.05),
                    p=p,
                    # same_on_batch=True,
                ),
                # Independent augmentations for each image in window
                K.RandomAffine(
                    degrees=2,
                    shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                    translate=(0.01, 0.01),
                    p=0.5 * p,
                    same_on_batch=False,
                ),
                # Intensity-based augmentations
                RandomIntensityScaleShift(
                    (0.5, 2.0), (-0.1, 0.1), clip_output=False, p=p, same_on_batch=True
                ),
                K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
            ]
        elif level == 4:
            # debug
            augs = [
                K.RandomAffine(
                    degrees=30,
                    shear=(-0, 0, -0, 0),  # x_min, x_max, y_min, y_max
                    translate=(0.0, 0.0),
                    p=1,
                    same_on_batch=True,
                ),
            ]
        else:
            raise ValueError(f"level {level} not supported")
        super().__init__(augs, filter_points)