| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d |
|
|
|
|
def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor:
    r"""Compute the first or second order image derivatives in both x and y.

    The derivative kernels come from ``get_spatial_gradient_kernel2d(mode, order)`` (Sobel by
    default). Each channel is filtered independently, after replicate padding by half the kernel
    size on every border.

    .. image:: _static/img/spatial_gradient.png

    Args:
        input: input image tensor with shape :math:`(B, C, H, W)`.
        mode: derivatives modality, can be: `sobel` or `diff`.
        order: the order of the derivatives.
        normalized: if True, the kernels are normalized with ``normalize_kernel2d`` before filtering.

    Return:
        the derivatives of the input feature map with shape :math:`(B, C, N, H, W)`, where
        :math:`N` is the number of kernels returned for ``(mode, order)``
        (2 for ``order=1``, 3 for ``order=2``).

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional.

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
        filtering_edges.html>`__.

    Examples:
        >>> input = torch.rand(1, 3, 4, 4)
        >>> output = spatial_gradient(input)  # 1x3x2x4x4
        >>> output.shape
        torch.Size([1, 3, 2, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    # Stack of 2D kernels with shape (N, kh, kw), one kernel per output derivative.
    kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
    if normalized:
        kernel = normalize_kernel2d(kernel)

    b, c, h, w = input.shape
    out_channels: int = kernel.size(0)
    kernel_h: int = kernel.size(1)
    kernel_w: int = kernel.size(2)

    # (N, kh, kw) -> (N, 1, 1, kh, kw): a conv3d weight with a single input channel and unit depth.
    # conv3d computes a cross-correlation, so the kernels are applied exactly as stored.
    # (The previous flip(-3) acted on a size-1 axis and was a no-op copy.)
    tmp_kernel: torch.Tensor = kernel.to(input).detach().unsqueeze(1).unsqueeze(1)

    # F.pad lists the last dimension first: (left, right, top, bottom) -> width before height.
    spatial_pad = [kernel_w // 2, kernel_w // 2, kernel_h // 2, kernel_h // 2]
    # Fold B and C together so every channel is filtered on its own, then add a unit depth axis.
    padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None]

    # (B*C, N, 1, H, W) -> (B, C, N, H, W)
    return F.conv3d(padded_inp, tmp_kernel, padding=0).view(b, c, out_channels, h, w)
|
|
|
|
def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor:
    r"""Compute the first and second order volume derivative in x, y and d.

    With ``mode='diff'`` and ``order=1``, a fast path computes central differences,
    :math:`0.5 \cdot (f(i + 1) - f(i - 1))`, on a replicate-padded copy of the input. Output
    channel 0 is the derivative along W (x), 1 along H (y), and 2 along D (d). Every other
    combination convolves each channel with the kernels returned by
    ``get_spatial_gradient_kernel3d(mode, order)``.

    Args:
        input: input features tensor with shape :math:`(B, C, D, H, W)`.
        mode: derivatives modality, can be: `sobel` or `diff`.
        order: the order of the derivatives.

    Return:
        the spatial gradients of the input feature map.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 5-dimensional.

    Shape:
        - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them.
        - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`

    Examples:
        >>> input = torch.rand(1, 4, 2, 4, 4)
        >>> output = spatial_gradient3d(input)
        >>> output.shape
        torch.Size([1, 4, 3, 2, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    b, c, d, h, w = input.shape

    if mode == 'diff' and order == 1:
        # Pad one voxel on every side so the central difference is defined at the borders.
        x: torch.Tensor = F.pad(input, 6 * [1], 'replicate')
        center = slice(1, -1)
        left = slice(0, -2)
        right = slice(2, None)
        out = torch.stack(
            [
                x[..., center, center, right] - x[..., center, center, left],  # along W (x)
                x[..., center, right, center] - x[..., center, left, center],  # along H (y)
                x[..., right, center, center] - x[..., left, center, center],  # along D (d)
            ],
            dim=2,
        )
        return 0.5 * out

    # Kernel stack of shape (N, 1, kd, kh, kw); tiled C times so that groups=c filters each
    # input channel with the same N kernels.
    kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
    tmp_kernel: torch.Tensor = kernel.to(input).detach().repeat(c, 1, 1, 1, 1)

    # NOTE(review): this flips the depth axis (dim -3) of every kernel. It is kept from the original
    # implementation; confirm its sign convention against the 'diff'/order=1 fast path.
    kernel_flip: torch.Tensor = tmp_kernel.flip(-3)

    out_ch: int = kernel.size(0)
    kernel_d, kernel_h, kernel_w = kernel.shape[-3:]
    # F.pad lists the last dimension first: (W_left, W_right, H_top, H_bottom, D_front, D_back).
    spatial_pad = [
        kernel_w // 2,
        kernel_w // 2,
        kernel_h // 2,
        kernel_h // 2,
        kernel_d // 2,
        kernel_d // 2,
    ]
    padded: torch.Tensor = F.pad(input, spatial_pad, 'replicate')
    # (B, C*N, D, H, W) -> (B, C, N, D, H, W)
    return F.conv3d(padded, kernel_flip, padding=0, groups=c).view(b, c, out_ch, d, h, w)
|
|
|
|
def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor:
    r"""Return the per-channel gradient magnitude obtained from Sobel derivatives.

    The magnitude is :math:`\sqrt{g_x^2 + g_y^2 + \epsilon}`. Here :math:`g_x` and :math:`g_y` are
    the first-order Sobel derivatives produced by ``spatial_gradient``.

    .. image:: _static/img/sobel.png

    Args:
        input: the input image with shape :math:`(B,C,H,W)`.
        normalized: if True, L1 norm of the kernel is set to 1.
        eps: small constant added under the square root to avoid NaN gradients during backprop.

    Return:
        the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.

    Raises:
        TypeError: if ``input`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-dimensional.

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
        filtering_edges.html>`__.

    Example:
        >>> input = torch.rand(1, 3, 4, 4)
        >>> output = sobel(input)  # 1x3x4x4
        >>> output.shape
        torch.Size([1, 3, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
    if input.dim() != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    # (B, C, 2, H, W): axis 2 holds the x and y derivatives.
    derivatives = spatial_gradient(input, normalized=normalized)
    grad_x = derivatives.select(2, 0)
    grad_y = derivatives.select(2, 1)

    squared_norm = grad_x * grad_x + grad_y * grad_y
    return torch.sqrt(squared_norm + eps)
|
|
|
|
class SpatialGradient(nn.Module):
    r"""Compute the first or second order image derivatives in both x and y.

    Module wrapper around :func:`spatial_gradient` that stores ``mode``, ``order`` and
    ``normalized`` and forwards them on every call.

    Args:
        mode: derivatives modality, can be: `sobel` or `diff`.
        order: the order of the derivatives.
        normalized: whether the derivative kernels are normalized before filtering.

    Return:
        the derivatives of the input feature map.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, 2, H, W)` for ``order=1``, :math:`(B, C, 3, H, W)` for ``order=2``

    Examples:
        >>> input = torch.rand(1, 3, 4, 4)
        >>> output = SpatialGradient()(input)  # 1x3x2x4x4
        >>> output.shape
        torch.Size([1, 3, 2, 4, 4])
    """

    def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None:
        super().__init__()
        self.normalized: bool = normalized
        self.order: int = order
        self.mode: str = mode

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(order={self.order!s}, normalized={self.normalized!s}, mode={self.mode})"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return spatial_gradient(input, self.mode, self.order, self.normalized)
|
|
|
|
class SpatialGradient3d(nn.Module):
    r"""Module form of :func:`spatial_gradient3d`: volume derivatives in x, y and d.

    Args:
        mode: derivatives modality, can be: `sobel` or `diff`.
        order: the order of the derivatives.

    Return:
        the spatial gradients of the input feature map.

    Shape:
        - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them.
        - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`

    Examples:
        >>> input = torch.rand(1, 4, 2, 4, 4)
        >>> output = SpatialGradient3d()(input)
        >>> output.shape
        torch.Size([1, 4, 3, 2, 4, 4])
    """

    def __init__(self, mode: str = 'diff', order: int = 1) -> None:
        super().__init__()
        self.order: int = order
        self.mode: str = mode
        # Kernel built eagerly at construction and exposed as an attribute; forward() does not read it
        # (spatial_gradient3d fetches its own kernel).
        self.kernel = get_spatial_gradient_kernel3d(mode, order)

    def __repr__(self) -> str:
        pieces = [self.__class__.__name__, '(order=', str(self.order), ', mode=', self.mode, ')']
        return ''.join(pieces)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return spatial_gradient3d(input, self.mode, self.order)
|
|
|
|
class Sobel(nn.Module):
    r"""Module form of :func:`sobel`: per-channel Sobel gradient magnitude.

    Args:
        normalized: if True, L1 norm of the kernel is set to 1.
        eps: small constant added under the square root to avoid NaN gradients during backprop.

    Return:
        the sobel edge gradient magnitudes map.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples:
        >>> input = torch.rand(1, 3, 4, 4)
        >>> output = Sobel()(input)  # 1x3x4x4
        >>> output.shape
        torch.Size([1, 3, 4, 4])
    """

    def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
        super().__init__()
        self.normalized: bool = normalized
        self.eps: float = eps

    def __repr__(self) -> str:
        # Only `normalized` is reported; `eps` is intentionally left out of the repr.
        return f"{self.__class__.__name__}(normalized={self.normalized!s})"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return sobel(input, self.normalized, self.eps)
|
|