| import math |
| from typing import Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from kornia.color import rgb_to_grayscale |
|
|
| from .gaussian import gaussian_blur2d |
| from .kernels import get_canny_nms_kernel, get_hysteresis_kernel |
| from .sobel import spatial_gradient |
|
|
|
|
def canny(
    input: torch.Tensor,
    low_threshold: float = 0.1,
    high_threshold: float = 0.2,
    kernel_size: Tuple[int, int] = (5, 5),
    sigma: Tuple[float, float] = (1, 1),
    hysteresis: bool = True,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Find edges of the input image and filters them using the Canny algorithm.

    .. image:: _static/img/canny.png

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`. Inputs with exactly
            3 channels are converted to grayscale before processing.
        low_threshold: lower threshold for the hysteresis procedure, in ``[0, 1]``.
        high_threshold: upper threshold for the hysteresis procedure, in ``[0, 1]``.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    Raises:
        TypeError: if ``input`` is not a :class:`torch.Tensor`.
        ValueError: if ``input`` is not 4-dimensional, if ``low_threshold > high_threshold``,
            or if either threshold lies outside ``[0, 1]``.

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
        canny.html>`__.

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = canny(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if input.dim() != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if low_threshold > high_threshold:
        raise ValueError(
            "Invalid input thresholds. low_threshold should be smaller than the high_threshold. "
            f"Got: {low_threshold}>{high_threshold}"
        )

    # A threshold can never be both < 0 and > 1, so the range checks must use ``or``
    # to actually reject values outside [0, 1].
    if low_threshold < 0 or low_threshold > 1:
        raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")

    if high_threshold < 0 or high_threshold > 1:
        raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")

    # Captured before any conversion so the kernels and outputs follow the caller's tensor.
    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # Three-channel images are collapsed to a single intensity channel.
    if input.shape[1] == 3:
        input = rgb_to_grayscale(input)

    # Smooth before differentiating to reduce noise in the gradient.
    blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)

    # Spatial gradients: gx / gy are read from index 0 / 1 of dimension 2.
    gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)
    gx: torch.Tensor = gradients[:, :, 0]
    gy: torch.Tensor = gradients[:, :, 1]

    # Gradient magnitude (eps keeps sqrt differentiable at zero) and direction in radians.
    magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
    angle: torch.Tensor = torch.atan2(gy, gx)

    # Radians -> degrees, then quantize to the nearest multiple of 45 degrees.
    angle = 180.0 * angle / math.pi
    angle = torch.round(angle / 45) * 45

    # Non-maxima suppression responses; channels 0..7 are gathered by direction index below.
    nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
    nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)

    # Direction index in [0, 8) for the gradient direction and for its opposite (+4 = +180 deg).
    positive_idx: torch.Tensor = ((angle / 45) % 8).long()
    negative_idx: torch.Tensor = (((angle / 45) + 4) % 8).long()

    channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
    channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)
    channel_select_filtered: torch.Tensor = torch.stack(
        [channel_select_filtered_positive, channel_select_filtered_negative], 1
    )

    # Keep only pixels whose NMS response is positive along both directions.
    is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
    magnitude = magnitude * is_max

    # Double threshold: 0 = suppressed, 0.5 = weak, 1 = strong.
    low: torch.Tensor = magnitude > low_threshold
    high: torch.Tensor = magnitude > high_threshold
    edges: torch.Tensor = (low * 0.5 + high * 0.5).to(dtype)

    if hysteresis:
        edges_old: torch.Tensor = torch.full_like(edges, -1.0)
        hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)

        # Repeatedly promote weak pixels linked to a strong pixel until the map stops changing.
        while ((edges_old - edges).abs() != 0).any():
            # Masks are built in ``dtype`` (not float32) so ``edges`` keeps the input dtype.
            # Otherwise half/bfloat16 inputs would mismatch the ``dtype`` kernels in conv2d.
            weak: torch.Tensor = (edges == 0.5).to(dtype)
            strong: torch.Tensor = (edges == 1).to(dtype)

            hysteresis_magnitude: torch.Tensor = F.conv2d(
                edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
            )
            hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
            hysteresis_magnitude = hysteresis_magnitude * weak + strong

            edges_old = edges.clone()
            edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5

        # Weak pixels never connected to a strong edge are discarded. The value-1 pixels of the
        # final map equal the last promoted mask, so this does not rely on a loop-local variable.
        edges = (edges == 1).to(dtype)

    return magnitude, edges
|
|
|
|
class Canny(nn.Module):
    r"""Module that finds edges of the input image and filters them using the Canny algorithm.

    Args:
        low_threshold: lower threshold for the hysteresis procedure, in ``[0, 1]``.
        high_threshold: upper threshold for the hysteresis procedure, in ``[0, 1]``.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Raises:
        ValueError: if ``low_threshold > high_threshold`` or either threshold lies outside ``[0, 1]``.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    Shape:
        - Input: :math:`(B, C, H, W)`

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = Canny()(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """

    def __init__(
        self,
        low_threshold: float = 0.1,
        high_threshold: float = 0.2,
        kernel_size: Tuple[int, int] = (5, 5),
        sigma: Tuple[float, float] = (1, 1),
        hysteresis: bool = True,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        # Validate eagerly so a misconfigured module fails at construction, not at first forward.
        # The message is built by implicit concatenation; a backslash continuation inside the
        # literal would embed the next line's indentation into the error text.
        if low_threshold > high_threshold:
            raise ValueError(
                "Invalid input thresholds. low_threshold should be smaller than the high_threshold. "
                f"Got: {low_threshold}>{high_threshold}"
            )

        if low_threshold < 0 or low_threshold > 1:
            raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")

        if high_threshold < 0 or high_threshold > 1:
            raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")

        # Gaussian blur parameters.
        self.kernel_size = kernel_size
        self.sigma = sigma

        # Double-threshold bounds.
        self.low_threshold = low_threshold
        self.high_threshold = high_threshold

        # Whether weak edges are resolved by hysteresis tracking.
        self.hysteresis = hysteresis

        # Regularizer for the gradient-magnitude sqrt.
        self.eps: float = eps

    def __repr__(self) -> str:
        # Lists every public instance attribute in sorted order (including nn.Module's ``training``).
        public = ', '.join(
            f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_')
        )
        return f'{type(self).__name__}({public})'

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run :func:`canny` on ``input`` with this module's configuration."""
        return canny(
            input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps
        )
|
|