hbyecoding's picture
Upload 143 files
b2c5353 verified
# This code was originally written by Jose Javier Gonzalez Ortiz
# for use in UniverSeg (https://github.com/JJGO/UniverSeg).
# It is included here with their permission, without modifications.
import random
from typing import Any, Optional, Union
import kornia as K
import kornia.augmentation as KA
import numpy as np
import torch
from kornia.constants import BorderType
from pydantic import validate_arguments
from .common import AugmentationBase2D, _as2tuple, _as_single_val
class RandomBrightnessContrast(AugmentationBase2D):
    """Apply a brightness and a contrast adjustment in a random order.

    Args:
        brightness: Brightness factor, or a pair that ``_as_single_val``
            reduces to one value per call (presumably by sampling from the
            range -- confirm in ``.common``).
        contrast: Contrast factor, resolved the same way as ``brightness``.
        same_on_batch: Forwarded to ``AugmentationBase2D``.
        p: Forwarded to ``AugmentationBase2D``.
        keepdim: Forwarded to ``AugmentationBase2D``.
    """

    def __init__(
        self,
        brightness: Union[float, tuple[float, float]] = 0.0,
        contrast: Union[float, tuple[float, float]] = 1.0,
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim)
        self.brightness = brightness
        self.contrast = contrast

    def generate_parameters(self, input_shape: torch.Size):
        """Resolve both factors and pick the order in which they are applied.

        Returns:
            A dict with keys ``brightness``, ``contrast`` and ``order``;
            ``order`` is a permutation of ``[0, 1]`` where index 0 stands for
            the brightness step and index 1 for the contrast step.
        """
        sampled = {
            "brightness": _as_single_val(self.brightness),
            "contrast": _as_single_val(self.contrast),
        }
        sampled["order"] = np.random.permutation(2)
        return sampled

    def apply_transform(
        self,
        input: torch.Tensor,
        params: dict[str, torch.Tensor],
        flags: dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the two kornia adjustments on ``input`` following ``params["order"]``."""

        def apply_brightness(img):
            return K.enhance.adjust_brightness(img, params["brightness"])

        def apply_contrast(img):
            return K.enhance.adjust_contrast(img, params["contrast"])

        # Position in this tuple must match the indices stored in "order".
        steps = (apply_brightness, apply_contrast)
        result = input
        for step_index in params["order"].tolist():
            result = steps[step_index](result)
        return result
class FilterBase(AugmentationBase2D):
    """Shared constructor for filter-style augmentations.

    Stores ``kernel_size`` and ``sigma`` on the instance. The constructor
    arguments are checked by pydantic's ``validate_arguments`` decorator.
    """

    @validate_arguments
    def __init__(
        self,
        kernel_size: Union[int, tuple[int, int]],
        sigma: Union[float, tuple[float, float]],
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        base_kwargs = {
            "p": p,
            "same_on_batch": same_on_batch,
            "p_batch": 1.0,
            "keepdim": keepdim,
        }
        super().__init__(**base_kwargs)
        self.kernel_size, self.sigma = kernel_size, sigma
class VariableFilterBase(FilterBase):
    """Base for filters whose kernel size and sigma are resolved per call."""

    def generate_parameters(self, input_shape: torch.Size):
        """Return ``kernel_size`` and ``sigma`` reduced to single values.

        Both are passed through ``_as_single_val`` (presumably a random draw
        when a range was configured -- confirm in ``.common``).
        ``input_shape`` is not used.
        """
        return {
            "kernel_size": _as_single_val(self.kernel_size),
            "sigma": _as_single_val(self.sigma),
        }
class RandomVariableGaussianBlur(VariableFilterBase):
    """Gaussian blur using the per-call kernel size and sigma.

    ``VariableFilterBase.generate_parameters`` resolves ``kernel_size`` and
    ``sigma`` to single values for each call; ``apply_transform`` forwards
    those values to ``kornia.filters.gaussian_blur2d``.

    Args:
        kernel_size: Kernel size, or a pair resolved per call.
        sigma: Gaussian sigma, or a pair resolved per call.
        border_type: Border mode name; looked up through ``BorderType.get``.
        same_on_batch: Forwarded to the base class.
        p: Forwarded to the base class.
        keepdim: Forwarded to the base class.
    """

    def __init__(
        self,
        kernel_size: Union[int, tuple[int, int]],
        sigma: Union[float, tuple[float, float]],
        border_type: str = "reflect",
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(
            kernel_size=kernel_size,
            sigma=sigma,
            p=p,
            same_on_batch=same_on_batch,
            keepdim=keepdim,
        )
        self.flags = dict(border_type=BorderType.get(border_type))

    def apply_transform(
        self,
        input: torch.Tensor,
        params: dict[str, torch.Tensor],
        flags: dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Blur ``input`` with the kernel size and sigma found in ``params``.

        Returns:
            The output of ``kornia.filters.gaussian_blur2d``.
        """
        # Use the values produced by generate_parameters. Reading
        # self.kernel_size / self.sigma here would discard them and treat a
        # configured (low, high) range as a fixed anisotropic kernel/sigma.
        # NOTE(review): kornia expects odd kernel sizes -- confirm that
        # `_as_single_val` only yields odd values for integer ranges.
        kernel_size = _as2tuple(params["kernel_size"])
        sigma = _as2tuple(params["sigma"])
        # The flag holds a BorderType member; kornia takes the lowercase name.
        border_name = flags["border_type"].name.lower()
        return K.filters.gaussian_blur2d(input, kernel_size, sigma, border_name)
class RandomVariableBoxBlur(AugmentationBase2D):
    """Box blur whose kernel size is resolved per call.

    Args:
        kernel_size: Kernel size, or a pair that ``_as_single_val`` reduces
            to one value per call (presumably by sampling -- confirm in
            ``.common``).
        border_type: Border mode passed to ``kornia.filters.box_blur``.
        normalized: ``normalized`` flag passed to ``kornia.filters.box_blur``.
        same_on_batch: Forwarded to ``AugmentationBase2D``.
        p: Forwarded to ``AugmentationBase2D``.
        keepdim: Forwarded to ``AugmentationBase2D``.
    """

    def __init__(
        self,
        kernel_size: Union[int, tuple[int, int]] = 3,
        border_type: str = "reflect",
        normalized: bool = True,
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(
            p=p,
            same_on_batch=same_on_batch,
            p_batch=1.0,
            keepdim=keepdim,
        )
        # generate_parameters reads self.kernel_size, so it has to be stored;
        # without this assignment the attribute lookup fails at call time.
        self.kernel_size = kernel_size
        self.flags = dict(border_type=border_type, normalized=normalized)

    def generate_parameters(self, input_shape: torch.Size):
        """Return the kernel size for this call; ``input_shape`` is not used."""
        kernel_size = _as_single_val(self.kernel_size)
        return dict(kernel_size=kernel_size)

    def apply_transform(
        self,
        input: torch.Tensor,
        params: dict[str, torch.Tensor],
        flags: dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Blur ``input`` with ``kornia.filters.box_blur`` using ``params``/``flags``."""
        kernel_size = _as2tuple(params["kernel_size"])
        return K.filters.box_blur(
            input, kernel_size, flags["border_type"], flags["normalized"]
        )
class RandomVariableGaussianNoise(AugmentationBase2D):
    """Add Gaussian noise whose mean and std are resolved per call.

    Args:
        mean: Noise mean, or a pair that ``_as_single_val`` reduces to one
            value per call (presumably by sampling -- confirm in ``.common``).
        std: Noise standard deviation, resolved the same way as ``mean``.
        same_on_batch: When true, one noise sample is drawn and repeated for
            every element of the batch.
        p: Forwarded to ``AugmentationBase2D``.
        keepdim: Forwarded to ``AugmentationBase2D``.
    """

    def __init__(
        self,
        mean: Union[float, tuple[float, float]] = 0.0,
        std: Union[float, tuple[float, float]] = 1.0,
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(
            p=p,
            same_on_batch=same_on_batch,
            p_batch=1.0,
            keepdim=keepdim,
        )
        self.mean = mean
        self.std = std

    def generate_parameters(self, input_shape: torch.Size):
        """Draw a float32 noise tensor of shape ``input_shape``.

        The tensor is created on CUDA when it is available, otherwise on the
        CPU. With ``same_on_batch`` a single sample is drawn and repeated
        along the first dimension so every batch element gets the same noise.

        Returns:
            A dict with the single key ``noise``.
        """
        mean = _as_single_val(self.mean)
        std = _as_single_val(self.std)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        shape = tuple(input_shape)
        if self.same_on_batch and len(shape) > 0:
            # Draw one sample, then tile it over the batch dimension.
            single = torch.empty(
                (1, *shape[1:]), dtype=torch.float32, device=device
            ).normal_(mean, std)
            noise = single.repeat(shape[0], *([1] * (len(shape) - 1)))
        else:
            noise = torch.empty(shape, dtype=torch.float32, device=device).normal_(
                mean, std
            )
        return dict(noise=noise)

    def apply_transform(
        self,
        input: torch.Tensor,
        params: dict[str, torch.Tensor],
        flags: dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``input`` plus the noise, moved to ``input``'s device and dtype."""
        return input + params["noise"].to(input)
def validate_elastic_sigma_alpha(sigma, alpha):
    """Raise if ``sigma`` and ``alpha`` look like they were passed swapped.

    When either argument is a tuple or list, its maximum is used for the
    comparison.

    Args:
        sigma: A number, or a tuple/list of numbers.
        alpha: A number, or a tuple/list of numbers.

    Raises:
        ValueError: If ``sigma / alpha`` is smaller than 1.
    """
    if isinstance(alpha, (tuple, list)):
        alpha = max(alpha)
    if isinstance(sigma, (tuple, list)):
        sigma = max(sigma)
    if alpha == 0:
        # Zero alpha cannot be compared by ratio; skip the check instead of
        # failing with ZeroDivisionError.
        return
    if sigma / alpha < 1:
        raise ValueError("Alpha and Sigma seem to be swapped")
class RandomVariableElasticTransform(AugmentationBase2D):
    """Elastic deformation with per-call kernel size, sigma and alpha.

    The configuration is kept in ``self.flags``. For each call,
    ``generate_parameters`` draws a uniform noise field in ``[-1, 1)`` and
    resolves ``kernel_size``, ``sigma`` and ``alpha`` through
    ``_as_single_val``; ``apply_transform`` passes everything to
    ``kornia.geometry.transform.elastic_transform2d``.
    """

    def __init__(
        self,
        kernel_size: Union[int, tuple[int, int]] = 63,
        sigma: Union[float, tuple[float, float]] = 32,
        alpha: Union[float, tuple[float, float]] = 1.0,
        align_corners: bool = False,
        mode: str = "bilinear",
        padding_mode: str = "zeros",
        same_on_batch: bool = False,
        p: float = 0.5,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim)
        # Reject configurations where sigma and alpha appear to be swapped.
        validate_elastic_sigma_alpha(sigma, alpha)
        self.flags = {
            "kernel_size": kernel_size,
            "sigma": sigma,
            "alpha": alpha,
            "align_corners": align_corners,
            "mode": mode,
            "padding_mode": padding_mode,
        }

    def generate_parameters(self, shape: torch.Size) -> dict[str, torch.Tensor]:
        """Draw the noise field and resolve the filter settings for one call.

        Returns:
            A dict with keys ``noise``, ``kernel_size``, ``sigma`` and
            ``alpha``. ``noise`` has shape ``(B, 2, H, W)``; with
            ``same_on_batch`` a single field is repeated ``B`` times.
        """
        batch, _, height, width = shape
        # kornia would otherwise create the noise on self.device (CPU by
        # default), so the conv2d calls that low-pass filter the noise would
        # run on the CPU regardless of where the input lives. Forcing the
        # noise onto CUDA whenever it is available avoids that.
        noise_device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.same_on_batch:
            field = torch.rand(
                1, 2, height, width, device=noise_device, dtype=self.dtype
            )
            field = field.repeat(batch, 1, 1, 1)
        else:
            field = torch.rand(
                batch, 2, height, width, device=noise_device, dtype=self.dtype
            )
        return {
            # Rescale from [0, 1) to [-1, 1).
            "noise": field * 2 - 1,
            "kernel_size": _as_single_val(self.flags["kernel_size"]),
            "sigma": _as_single_val(self.flags["sigma"]),
            "alpha": _as_single_val(self.flags["alpha"]),
        }

    def apply_transform(
        self,
        input: torch.Tensor,
        params: dict[str, torch.Tensor],
        flags: dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Warp ``input`` with kornia's ``elastic_transform2d``.

        The noise is deliberately not moved to ``input``'s device; a device
        mismatch trips the assertion instead.
        """
        assert (
            input.device == params["noise"].device
        ), f"Input/Noise with different devices {input.device} & {params['noise'].device}"
        kernel_pair = _as2tuple(params["kernel_size"])
        sigma_pair = _as2tuple(params["sigma"])
        alpha_pair = _as2tuple(params["alpha"])
        return K.geometry.transform.elastic_transform2d(
            input,
            params["noise"],
            kernel_pair,
            sigma_pair,
            alpha_pair,
            flags["align_corners"],
            flags["mode"],
            flags["padding_mode"],
        )