| import math |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    .. image:: _static/img/rgb_to_hsv.png

    The image data is assumed to be in the range of (0, 1).

    Args:
        image: RGB Image to be converted to HSV with shape of :math:`(*, 3, H, W)`.
        eps: scalar to enforce numerical stability. It is added to the per-pixel
            maximum when computing the saturation, so black pixels get ``S == 0``
            instead of a division by zero.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.
        The H channel values are in the range 0..2pi. S and V are in the range 0..1.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` has fewer than 3 dims or dim ``-3`` is not of size 3.

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
        color_conversions.html>`__.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hsv(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Per-pixel channel extrema. argmax_rgb decides which hue formula applies below;
    # only the values of the minimum are needed.
    max_rgb, argmax_rgb = image.max(-3)
    min_rgb = image.min(-3).values
    deltac = max_rgb - min_rgb

    v = max_rgb
    # eps keeps 0 / 0 (black pixels) finite: s becomes 0 there.
    s = deltac / (max_rgb + eps)

    # Hue is undefined for achromatic pixels (deltac == 0); divide by 1 instead of 0
    # so the hue computation below never produces NaN/inf.
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind(max_rgb.unsqueeze(-3) - image, dim=-3)

    # Candidate hues (in sixths of a turn, scaled by deltac) for max == R, G, B.
    h1 = bc - gc
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac

    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    # Pick, per pixel, the candidate belonging to the channel that holds the maximum.
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    # Wrap into [0, 1) turns, then express in radians.
    h = (h / 6.0) % 1.0
    h = 2.0 * math.pi * h

    return torch.stack((h, s, v), dim=-3)
|
|
|
|
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    The H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1.
    Hue values outside 0..2pi are wrapped around, because the hue sector index is taken
    modulo 6.

    Args:
        image: HSV Image to be converted to RGB with shape of :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`.

    Raises:
        TypeError: if ``image`` is not a ``torch.Tensor``.
        ValueError: if ``image`` has fewer than 3 dims or dim ``-3`` is not of size 3.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Hue is converted from radians to turns so that h * 6 addresses the six sectors.
    h: torch.Tensor = image[..., 0, :, :] / (2 * math.pi)
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :]

    # hi: integer hue sector in 0..5; f: fractional position inside that sector.
    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    # 18 candidate planes: entries 0-5 are R, 6-11 are G, 12-17 are B for sectors 0..5.
    # Gathering at (hi, hi + 6, hi + 12) selects the R, G, B values of each pixel's sector.
    hi = hi.long()
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)

    return out
|
|
|
|
class RgbToHsv(nn.Module):
    r"""Convert an image from RGB to HSV.

    The image data is assumed to be in the range of (0, 1).

    Args:
        eps: scalar to enforce numerical stability, forwarded to :func:`rgb_to_hsv`.
            Note that this module's default (``1e-6``) differs from the default of
            the functional :func:`rgb_to_hsv` (``1e-8``).

    Returns:
        HSV version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> hsv = RgbToHsv()
        >>> output = hsv(input)  # 2x3x4x5
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        # Stored as a plain attribute (not a buffer/parameter); passed on every forward call.
        self.eps: float = eps

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return rgb_to_hsv(image, self.eps)
|
|
|
|
class HsvToRgb(nn.Module):
    r"""Module form of :func:`hsv_to_rgb`: turn an HSV image back into RGB.

    Expects the hue channel in radians within 0..2pi, and saturation and value
    within 0..1. The module has no parameters or state.

    Returns:
        RGB version of the image.

    Shape:
        - image: :math:`(*, 3, H, W)`
        - output: :math:`(*, 3, H, W)`

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> rgb = HsvToRgb()
        >>> output = rgb(input)  # 2x3x4x5
    """

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # All validation and conversion is delegated to the functional implementation.
        converted: torch.Tensor = hsv_to_rgb(image)
        return converted
|
|