| | import torch |
| | from src.utils.color import to_float_rgb |
| |
|
| |
|
| | __all__ = ['rgb2hsv', 'rgb2lab'] |
| |
|
| |
|
def rgb2hsv(rgb, epsilon=1e-10):
    """Convert a 2D tensor of RGB colors to HSV format.

    The hue is computed from the smallest channel of each color (one
    formula per possible argmin), which avoids any branching.

    Credit: https://www.linuxtut.com/en/20819a90872275811439

    :param rgb: Tensor of shape (N, 3) holding int colors in [0, 255] or
        float colors in [0, 1]. The input is cloned before being passed
        to `to_float_rgb`.
    :param epsilon: Small constant added to the denominators to avoid
        divisions by zero for grey and black colors.
    :return: Tensor of shape (N, 3) whose columns are (h, s, v). Hue is
        expressed in degrees in [0, 360]; for achromatic colors (all
        channels equal) it carries no meaning. Saturation is
        (max - min) / max and value is the largest channel, both
        computed on the output of `to_float_rgb`.
    """
    assert rgb.ndim == 2, f"expected a 2D tensor, got {rgb.ndim}D"
    assert rgb.shape[1] == 3, \
        f"expected 3 channels in dim 1, got {rgb.shape[1]}"

    # Work on a copy so that nothing done downstream can touch the
    # caller's tensor
    rgb = to_float_rgb(rgb.clone())

    r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2]
    max_rgb = rgb.max(dim=1).values
    min_rgb, argmin_rgb = rgb.min(dim=1)

    # Keep the exact chroma for the saturation and only pad the
    # denominator of the hue with epsilon
    chroma = max_rgb - min_rgb
    hue_denominator = chroma + epsilon

    # One hue candidate per "smallest channel" case
    h_min_is_b = 60.0 * (g - r) / hue_denominator + 60.0
    h_min_is_r = 60.0 * (b - g) / hue_denominator + 180.0
    h_min_is_g = 60.0 * (r - b) / hue_denominator + 300.0

    # Candidates are stacked in channel order (r, g, b) so that argmin
    # directly indexes the right formula
    candidates = torch.stack((h_min_is_r, h_min_is_g, h_min_is_b), dim=0)
    h = candidates.gather(dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)

    # The numerator must NOT include epsilon, otherwise pure black would
    # get a saturation of epsilon / epsilon = 1 instead of 0
    s = chroma / (max_rgb + epsilon)
    v = max_rgb

    return torch.stack((h, s, v), dim=1)
| |
|
| |
|
def rgb2lab(rgb):
    """Convert a tensor of RGB colors to LAB colors.

    Reimplemented from:
    https://gist.github.com/manojpandey/f5ece715132c572c80421febebaf66ae

    :param rgb: Tensor whose last dimension holds the 3 RGB channels,
        either int in [0, 255] or float in [0, 1]. The input is cloned
        before being passed to `to_float_rgb`, so the in-place
        operations below do not modify the caller's tensor.
    :return: Tensor with the same shape as the converted input, whose
        last dimension holds (L, a, b), rounded to 4 decimals.
    """
    rgb = rgb.clone()
    device = rgb.device

    # Convert colors to float, if need be
    rgb = to_float_rgb(rgb)

    # Build the constant tensors with the dtype of the converted colors
    # so that the matrix products below do not fail on a dtype mismatch
    dtype = rgb.dtype

    # Linearize the sRGB values (inverse gamma), then rescale to
    # [0, 100]
    mask = rgb > 0.04045
    rgb[mask] = ((rgb[mask] + 0.055) / 1.055) ** 2.4
    rgb[~mask] = rgb[~mask] / 12.92
    rgb *= 100

    # Linear RGB to XYZ. The matrix is laid out for a right-hand
    # product: each row holds the (X, Y, Z) contribution of one channel
    rgb_to_xyz = torch.tensor([
        [0.4124, 0.2126, 0.0193],
        [0.3576, 0.7152, 0.1192],
        [0.1805, 0.0722, 0.9505]], device=device, dtype=dtype)
    xyz = (rgb @ rgb_to_xyz).round(decimals=4)

    # Normalize by the reference white point (these values match the
    # D65 illuminant)
    white = torch.tensor([[95.047, 100.0, 108.883]], device=device, dtype=dtype)
    xyz /= white

    # Non-linear compression: cube root above the threshold, linear
    # segment below. Note 1 / 7.25 == 16 / 116
    mask = xyz > 0.008856
    xyz[mask] = xyz[mask] ** (1 / 3.)
    xyz[~mask] = 7.787 * xyz[~mask] + 1 / 7.25

    # Compressed XYZ to LAB, as a single matrix product:
    #   L = 116 * y - 16
    #   a = 500 * (x - y)
    #   b = 200 * (y - z)
    xyz_to_lab = torch.tensor([
        [0, 500, 0],
        [116, -500, 200],
        [0, 0, -200]], device=device, dtype=dtype)
    lab = xyz @ xyz_to_lab

    # Index the channel through the LAST dimension, so the offset hits
    # L whatever the number of leading dimensions
    lab[..., 0] -= 16

    return lab.round(decimals=4)
| |
|