| | import torch |
| |
|
| |
|
# Short aliases used in the annotations and padding calls throughout this module.
Tensor = torch.Tensor
# `torch.device` / `torch.dtype` are the classes of the objects callers actually
# pass as `device=` / `dtype=`. The previous bindings (`torch.DeviceObjType`,
# `torch.Type`) are TorchScript type descriptors, not the runtime value types.
Device = torch.device
Dtype = torch.dtype
pad = torch.nn.functional.pad
| |
|
| |
|
def _compute_zero_padding(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
    """Return the per-side padding ``(pad_y, pad_x)`` for a 2D kernel.

    Each value is ``(k - 1) // 2`` for the corresponding kernel extent, where
    the extents come from ``_unpack_2d_ks(kernel_size)``.
    """
    height, width = _unpack_2d_ks(kernel_size)
    pad_y = (height - 1) // 2
    pad_x = (width - 1) // 2
    return pad_y, pad_x
| |
|
| |
|
| | def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]: |
| | if isinstance(kernel_size, int): |
| | ky = kx = kernel_size |
| | else: |
| | assert len(kernel_size) == 2, '2D Kernel size should have a length of 2.' |
| | ky, kx = kernel_size |
| |
|
| | ky = int(ky) |
| | kx = int(kx) |
| | return ky, kx |
| |
|
| |
|
| | def gaussian( |
| | window_size: int, sigma: Tensor | float, *, device: Device | None = None, dtype: Dtype | None = None |
| | ) -> Tensor: |
| |
|
| | batch_size = sigma.shape[0] |
| |
|
| | x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) |
| |
|
| | if window_size % 2 == 0: |
| | x = x + 0.5 |
| |
|
| | gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) |
| |
|
| | return gauss / gauss.sum(-1, keepdim=True) |
| |
|
| |
|
def get_gaussian_kernel1d(
    kernel_size: int,
    sigma: float | Tensor,
    force_even: bool = False,
    *,
    device: Device | None = None,
    dtype: Dtype | None = None,
) -> Tensor:
    """Return the 1D Gaussian kernel produced by ``gaussian``.

    Args:
        kernel_size: Number of taps in the kernel.
        sigma: Standard deviation, forwarded unchanged to ``gaussian``.
        force_even: Accepted for signature compatibility; this
            implementation does not use it.
        device: Forwarded to ``gaussian``.
        dtype: Forwarded to ``gaussian``.

    Returns:
        Whatever ``gaussian(kernel_size, sigma, device=..., dtype=...)`` returns.
    """
    kernel = gaussian(kernel_size, sigma, device=device, dtype=dtype)
    return kernel
| |
|
| |
|
def get_gaussian_kernel2d(
    kernel_size: tuple[int, int] | int,
    sigma: tuple[float, float] | Tensor,
    force_even: bool = False,
    *,
    device: Device | None = None,
    dtype: Dtype | None = None,
) -> Tensor:
    """Build separable 2D Gaussian kernel(s) as the outer product of two 1D kernels.

    Args:
        kernel_size: Kernel extent, an int or a ``(ky, kx)`` pair.
        sigma: Standard deviation(s). Accepted forms:

            * a plain number - used for both axes;
            * a ``(sigma_y, sigma_x)`` pair;
            * a tensor with a single element - used for both axes;
            * a tensor of shape ``(2,)`` or ``(B, 2)`` holding
              ``(sigma_y, sigma_x)`` per row.
        force_even: Forwarded to ``get_gaussian_kernel1d``.
        device: Device the sigma tensor is created on / moved to.
        dtype: Dtype the sigma tensor is created with / converted to.

    Returns:
        Tensor of shape ``(B, ky, kx)`` (``B == 1`` for non-batched sigma).

    Raises:
        ValueError: If ``sigma`` cannot be interpreted as ``(B, 2)``.
    """
    if isinstance(sigma, Tensor):
        sigma = sigma.to(device=device, dtype=dtype)
        if sigma.numel() == 1:
            # Scalar tensor: same standard deviation along both axes.
            sigma = sigma.reshape(1, 1).expand(1, 2)
        elif sigma.dim() == 1:
            # A lone (sigma_y, sigma_x) vector becomes a batch of one.
            sigma = sigma.unsqueeze(0)
    elif isinstance(sigma, (tuple, list)):
        if len(sigma) != 2:
            raise ValueError(f"sigma should hold (sigma_y, sigma_x); got {len(sigma)} values.")
        sigma = torch.tensor([[float(sigma[0]), float(sigma[1])]], device=device, dtype=dtype)
    else:
        # Plain number: isotropic kernel.
        sigma = torch.tensor([[float(sigma), float(sigma)]], device=device, dtype=dtype)

    if sigma.dim() != 2 or sigma.shape[-1] != 2:
        raise ValueError(f"sigma should have shape (B, 2); got {tuple(sigma.shape)}.")

    ksize_y, ksize_x = _unpack_2d_ks(kernel_size)
    # Keep a trailing singleton dim so each sigma broadcasts over the kernel taps.
    sigma_y, sigma_x = sigma[:, 0, None], sigma[:, 1, None]

    kernel_y = get_gaussian_kernel1d(ksize_y, sigma_y, force_even, device=device, dtype=dtype)[..., None]
    kernel_x = get_gaussian_kernel1d(ksize_x, sigma_x, force_even, device=device, dtype=dtype)[..., None]

    # (B, ky, 1) * (B, 1, kx) -> (B, ky, kx)
    return kernel_y * kernel_x.view(-1, 1, ksize_x)
| |
|
| |
|
def _bilateral_blur(
    input: Tensor,
    guidance: Tensor | None,
    kernel_size: tuple[int, int] | int,
    sigma_color: float | Tensor,
    sigma_space: tuple[float, float] | Tensor,
    border_type: str = 'reflect',
    color_distance_type: str = 'l1',
) -> Tensor:
    """Shared implementation of the (joint) bilateral blur.

    Every output value is a weighted average of the ``ky * kx`` neighbourhood
    of the corresponding input position. The weight of a neighbour is the
    product of a spatial Gaussian (from ``sigma_space``) and a range term
    ``exp(-0.5 * d**2 / sigma_color**2)``, where ``d`` is the distance between
    the guidance value at the neighbour and at the centre, reduced over dim 1.

    Args:
        input: Tensor to filter; dims 2 and 3 are the windowed (spatial)
            dims - presumably ``(B, C, H, W)``, confirm against callers.
        guidance: Tensor the range weights are computed from; ``None`` means
            ``input`` guides itself.
        kernel_size: Window extent, an int or ``(ky, kx)``.
        sigma_color: Range standard deviation; a tensor is moved to the
            input's device/dtype and reshaped to ``(-1, 1, 1, 1, 1)``.
        sigma_space: Spatial standard deviation, forwarded to
            ``get_gaussian_kernel2d``.
        border_type: Padding mode handed to ``pad``.
        color_distance_type: ``"l1"`` (sum of absolute differences, then
            squared) or ``"l2"`` (sum of squared differences).

    Returns:
        The filtered tensor; for odd kernel sizes the padding keeps the
        spatial size of ``input``.

    Raises:
        ValueError: If ``color_distance_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if isinstance(sigma_color, Tensor):
        # One sigma per batch item, broadcastable against (B, 1, H, W, K).
        sigma_color = sigma_color.to(device=input.device, dtype=input.dtype).view(-1, 1, 1, 1, 1)

    ky, kx = _unpack_2d_ks(kernel_size)
    pad_y, pad_x = _compute_zero_padding(kernel_size)

    def _extract_patches(image: Tensor) -> Tensor:
        # Pad the borders, then gather the ky x kx window around each
        # position into a flattened trailing dimension of size ky * kx.
        padded = pad(image, (pad_x, pad_x, pad_y, pad_y), mode=border_type)
        return padded.unfold(2, ky, 1).unfold(3, kx, 1).flatten(-2)

    input_patches = _extract_patches(input)

    if guidance is None:
        # Plain bilateral filter: the image is its own guide.
        guidance = input
        guidance_patches = input_patches
    else:
        guidance_patches = _extract_patches(guidance)

    # Difference between every neighbour and the window centre.
    offsets = guidance_patches - guidance.unsqueeze(-1)
    if color_distance_type == "l1":
        squared_distance = offsets.abs().sum(1, keepdim=True).square()
    elif color_distance_type == "l2":
        squared_distance = offsets.square().sum(1, keepdim=True)
    else:
        raise ValueError("color_distance_type only acceps l1 or l2")
    range_weights = torch.exp(-0.5 / sigma_color**2 * squared_distance)

    spatial_weights = get_gaussian_kernel2d(kernel_size, sigma_space, device=input.device, dtype=input.dtype)
    spatial_weights = spatial_weights.view(-1, 1, 1, 1, kx * ky)

    weights = spatial_weights * range_weights
    # Normalised weighted average over the window dimension.
    blurred = (input_patches * weights).sum(-1) / weights.sum(-1)
    return blurred
| |
|
| |
|
def bilateral_blur(
    input: Tensor,
    kernel_size: tuple[int, int] | int = (13, 13),
    sigma_color: float | Tensor = 3.0,
    sigma_space: float | tuple[float, float] | Tensor = 3.0,
    border_type: str = 'reflect',
    color_distance_type: str = 'l1',
) -> Tensor:
    """Bilateral blur where ``input`` serves as its own guidance.

    Thin wrapper around ``_bilateral_blur`` with ``guidance=None``.

    Args:
        input: Tensor to filter.
        kernel_size: Window extent, an int or ``(ky, kx)``.
        sigma_color: Range (intensity) standard deviation.
        sigma_space: Spatial standard deviation. The annotation includes
            ``float`` because the default value is a plain float.
        border_type: Padding mode used at the borders.
        color_distance_type: ``"l1"`` or ``"l2"``.

    Returns:
        The filtered tensor returned by ``_bilateral_blur``.
    """
    return _bilateral_blur(input, None, kernel_size, sigma_color, sigma_space, border_type, color_distance_type)
| |
|
| |
|
def adaptive_anisotropic_filter(
    x: Tensor,
    g: Tensor | None = None,
    *,
    kernel_size: tuple[int, int] | int = (13, 13),
    sigma_color: float | Tensor = 3.0,
    sigma_space: float | tuple[float, float] | Tensor = 3.0,
    border_type: str = 'reflect',
    color_distance_type: str = 'l1',
    eps: float = 1e-5,
) -> Tensor:
    """Joint bilateral blur of ``x`` guided by a per-sample standardised ``g``.

    The guide is shifted and scaled by its mean and standard deviation taken
    over dims ``(1, 2, 3)`` before the range weights are computed, so it must
    have at least four dimensions.

    Args:
        x: Tensor to filter.
        g: Guide tensor; defaults to ``x`` itself.
        kernel_size: Window extent forwarded to ``_bilateral_blur``.
        sigma_color: Range standard deviation (applied to the standardised guide).
        sigma_space: Spatial standard deviation, forwarded unchanged.
        border_type: Padding mode used at the borders.
        color_distance_type: ``"l1"`` or ``"l2"``.
        eps: Added to the standard deviation to avoid division by zero on
            constant guides.

    Returns:
        The filtered tensor returned by ``_bilateral_blur``.
    """
    if g is None:
        g = x
    std, mean = torch.std_mean(g, dim=(1, 2, 3), keepdim=True)
    std = std + eps
    guidance = (g - mean) / std
    return _bilateral_blur(
        x,
        guidance,
        kernel_size=kernel_size,
        sigma_color=sigma_color,
        sigma_space=sigma_space,
        border_type=border_type,
        color_distance_type=color_distance_type,
    )
| |
|
| |
|
def joint_bilateral_blur(
    input: Tensor,
    guidance: Tensor,
    kernel_size: tuple[int, int] | int,
    sigma_color: float | Tensor,
    sigma_space: tuple[float, float] | Tensor,
    border_type: str = 'reflect',
    color_distance_type: str = 'l1',
) -> Tensor:
    """Bilateral blur of ``input`` with range weights taken from ``guidance``.

    All arguments are forwarded unchanged to ``_bilateral_blur``, whose
    result is returned.
    """
    return _bilateral_blur(
        input=input,
        guidance=guidance,
        kernel_size=kernel_size,
        sigma_color=sigma_color,
        sigma_space=sigma_space,
        border_type=border_type,
        color_distance_type=color_distance_type,
    )
| |
|
| |
|
| | class _BilateralBlur(torch.nn.Module): |
| | def __init__( |
| | self, |
| | kernel_size: tuple[int, int] | int, |
| | sigma_color: float | Tensor, |
| | sigma_space: tuple[float, float] | Tensor, |
| | border_type: str = 'reflect', |
| | color_distance_type: str = "l1", |
| | ) -> None: |
| | super().__init__() |
| | self.kernel_size = kernel_size |
| | self.sigma_color = sigma_color |
| | self.sigma_space = sigma_space |
| | self.border_type = border_type |
| | self.color_distance_type = color_distance_type |
| |
|
| | def __repr__(self) -> str: |
| | return ( |
| | f"{self.__class__.__name__}" |
| | f"(kernel_size={self.kernel_size}, " |
| | f"sigma_color={self.sigma_color}, " |
| | f"sigma_space={self.sigma_space}, " |
| | f"border_type={self.border_type}, " |
| | f"color_distance_type={self.color_distance_type})" |
| | ) |
| |
|
| |
|
class BilateralBlur(_BilateralBlur):
    """Module form of ``bilateral_blur`` using the settings stored at construction."""

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``bilateral_blur`` to ``input`` with the stored settings."""
        return bilateral_blur(
            input,
            kernel_size=self.kernel_size,
            sigma_color=self.sigma_color,
            sigma_space=self.sigma_space,
            border_type=self.border_type,
            color_distance_type=self.color_distance_type,
        )
| |
|
| |
|
class JointBilateralBlur(_BilateralBlur):
    """Module form of ``joint_bilateral_blur`` using the settings stored at construction."""

    def forward(self, input: Tensor, guidance: Tensor) -> Tensor:
        """Apply ``joint_bilateral_blur`` to ``input`` guided by ``guidance``."""
        return joint_bilateral_blur(
            input,
            guidance,
            kernel_size=self.kernel_size,
            sigma_color=self.sigma_color,
            sigma_space=self.sigma_space,
            border_type=self.border_type,
            color_distance_type=self.color_distance_type,
        )
| |
|