| from typing import List |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from .__tmp__ import _deprecation_wrapper |
| from .kernels import normalize_kernel2d |
|
|
|
|
def _compute_padding(kernel_size: List[int]) -> List[int]:
    """Compute the ``F.pad`` padding list that keeps a stride-1 convolution 'same'-sized.

    A spatial dimension convolved (stride 1, no conv padding) with a kernel of
    extent ``k`` shrinks by ``k - 1``, so ``k - 1`` elements are padded in total.
    For odd ``k`` they are split evenly; for even ``k`` the extra element goes to
    the rear side (right / bottom / back).

    Args:
        kernel_size: kernel extent per spatial dimension, outermost first,
          e.g. ``[kH, kW]`` or ``[kD, kH, kW]``.

    Returns:
        ``2 * len(kernel_size)`` ints in the order ``torch.nn.functional.pad``
        expects, i.e. starting from the LAST dimension:
        ``[front_last, rear_last, front_second_last, rear_second_last, ...]``.

    Raises:
        AssertionError: if fewer than two kernel dimensions are given.
    """
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)

    out_padding: List[int] = []
    # F.pad consumes (front, rear) pairs starting from the last dimension, so walk
    # the kernel dims in reverse. Both the total and the parity come from the SAME
    # dimension, so non-square kernels with mixed odd/even extents are padded correctly.
    for k in reversed(kernel_size):
        total = k - 1
        front = total // 2
        out_padding += [front, total - front]
    return out_padding
|
|
|
|
def filter2d(
    input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'reflect', normalized: bool = False,
    padding: str = 'same'
) -> torch.Tensor:
    r"""Convolve a tensor with a 2d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. With ``padding='same'``
    the input is first padded according to ``border_type`` so that the output
    keeps the spatial size of the input. With ``padding='valid'`` no padding is
    applied and the output shrinks by ``kH - 1`` rows and ``kW - 1`` columns.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Only used when ``padding='same'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
          2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor with the same batch size and number of
        channels as the input. Its shape is :math:`(B, C, H, W)` for ``'same'``
        padding and :math:`(B, C, H - kH + 1, W - kW + 1)` for ``'valid'`` padding.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or ``border_type``/``padding``
          are not strings.
        ValueError: if ``border_type``/``padding`` are not one of the supported modes,
          ``input`` is not 4D, or ``kernel`` is not :math:`(1, kH, kW)` / :math:`(B, kH, kW)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3, 3)
        >>> filter2d(input, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
        # Implicit concatenation instead of a backslash continuation, which used to
        # embed the next line's leading whitespace in the message.
        raise ValueError(
            "Invalid border type, we expect 'constant', 'reflect', 'replicate', 'circular'. "
            f"Got:{border_type}"
        )

    if not isinstance(padding, str):
        raise TypeError(f"Input padding is not string. Got {type(padding)}")

    if padding not in ['valid', 'same']:
        raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    # The kernel must be 3D, with a leading dim of 1 (shared by the whole batch)
    # or B (one kernel per batch element). Anything else is rejected up front
    # instead of failing (or silently mis-pairing kernels) in the reshape below.
    if not (len(kernel.shape) == 3 and kernel.shape[0] in (1, input.shape[0])):
        raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")

    b, c, h, w = input.shape
    # (Bk, kH, kW) -> (Bk, 1, kH, kW), cast to the input's dtype/device.
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        tmp_kernel = normalize_kernel2d(tmp_kernel)

    # Repeat the kernel for every channel: (Bk, C, kH, kW).
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    if padding == 'same':
        padding_shape: List[int] = _compute_padding([height, width])
        input = F.pad(input, padding_shape, mode=border_type)

    # Fold (Bk, C) into conv groups so each (batch, channel) plane gets its own
    # filter: with Bk == 1 the input is (B, C, H', W') with C groups; with Bk == B
    # it becomes (1, B*C, H', W') with B*C groups.
    # `reshape` rather than `view`: in 'valid' mode the caller's tensor is used
    # as-is and may be non-contiguous (e.g. channels-last).
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.reshape(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    if padding == 'same':
        out = output.view(b, c, h, w)
    else:
        out = output.view(b, c, h - height + 1, w - width + 1)

    return out
|
|
|
|
def filter2d_separable(input: torch.Tensor,
                       kernel_x: torch.Tensor,
                       kernel_y: torch.Tensor,
                       border_type: str = 'reflect',
                       normalized: bool = False,
                       padding: str = 'same') -> torch.Tensor:
    r"""Convolve a tensor with two 1d kernels, in x and y directions.

    The input is first filtered along x (width) with ``kernel_x`` and the
    result is then filtered along y (height) with ``kernel_y``. Each pass is
    delegated to :func:`filter2d` with the same ``border_type``, ``normalized``
    and ``padding`` arguments. The kernels are therefore applied independently
    at each depth channel, and with ``normalized=True`` each 1d kernel is
    normalized on its own.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel_x: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
        kernel_y: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
          2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor with the same batch size and number of
        channels as the input. Its shape is :math:`(B, C, H, W)` for ``'same'``
        padding and :math:`(B, C, H - kH + 1, W - kW + 1)` for ``'valid'`` padding.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3)

        >>> filter2d_separable(input, kernel, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    # (Bk, kW) -> (Bk, 1, kW): a one-row 2d kernel filtering along x. The new axis
    # is inserted just before the last one so the batch axis stays first; inserting
    # it at dim 0 would turn a per-batch (B, kW) kernel into one (1, B, kW) kernel.
    out_x = filter2d(input, kernel_x.unsqueeze(-2), border_type, normalized, padding)
    # (Bk, kH) -> (Bk, kH, 1): a one-column 2d kernel filtering along y.
    out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
    return out
|
|
|
|
def filter3d(
    input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input: the input tensor with shape of
          :math:`(B, C, D, H, W)`.
        kernel: the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``,
          ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.

    Return:
        the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or ``border_type`` is not a string.
        ValueError: if ``input`` is not 5D or ``kernel`` is not
          :math:`(1, kD, kH, kW)` / :math:`(B, kD, kH, kW)`.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3d(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    # The kernel must be 4D, with a leading dim of 1 (shared by the batch) or B
    # (one kernel per batch element), as documented above.
    if not (len(kernel.shape) == 4 and kernel.shape[0] in (1, input.shape[0])):
        raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW or BxDxHxW. Got: {kernel.shape}")

    b, c, d, h, w = input.shape
    # (Bk, kD, kH, kW) -> (Bk, 1, kD, kH, kW), cast to the input's dtype/device.
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        # Flatten (kH, kW) so the 2d normalizer can be reused on a 3D tensor.
        # NOTE(review): the normalization itself is defined in .kernels; presumably
        # it sums over the last two dims (here kD and kH*kW) — confirm there.
        bk, dk, hk, wk = kernel.shape
        tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)

    # Repeat the kernel for every channel: (Bk, C, kD, kH, kW).
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = _compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # Fold (Bk, C) into conv groups so every (batch, channel) volume is filtered
    # by its own kernel.
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.reshape(
        -1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1)
    )

    output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    return output.view(b, c, d, h, w)
|
|
|
|
| |
# Camel-case aliases of ``filter2d`` / ``filter3d`` kept for backward compatibility.
# NOTE(review): ``_deprecation_wrapper`` is imported from ``.__tmp__`` and is not visible
# here; presumably it returns a callable that warns about the old name before
# delegating to the wrapped function — confirm in that module.
filter2D = _deprecation_wrapper(filter2d, 'filter2D')
filter3D = _deprecation_wrapper(filter3d, 'filter3D')
|
|