|
|
import torch |
|
|
from src.utils.color import to_float_rgb |
|
|
|
|
|
|
|
|
# Public API of this module: the two tensor-based RGB converters below.
__all__ = ['rgb2hsv', 'rgb2lab']
|
|
|
|
|
|
|
|
def rgb2hsv(rgb, epsilon=1e-10):
    """Convert a 2D tensor of RGB colors int [0, 255] or float [0, 1] to
    HSV format.

    Credit: https://www.linuxtut.com/en/20819a90872275811439

    :param rgb: tensor of shape (N, 3), one RGB color per row. It is
        normalized with ``to_float_rgb`` (presumably to float [0, 1]).
    :param epsilon: small constant added to denominators to avoid division
        by zero
    :return: tensor of shape (N, 3) holding (H, S, V) per row, where H is
        in degrees in [0, 360), S = (max - min) / max and V = max of the
        float RGB channels. Achromatic colors (R == G == B) get H = 0 and
        S = 0; in particular black maps to (0, 0, 0).
    """
    assert rgb.ndim == 2
    assert rgb.shape[1] == 3

    # Work on a copy so the caller's tensor is never modified
    rgb = to_float_rgb(rgb.clone())

    r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2]
    max_rgb = rgb.max(dim=1).values
    min_rgb, argmin_rgb = rgb.min(dim=1)
    chroma = max_rgb - min_rgb

    # Hue depends on which channel holds the minimum: compute all three
    # candidates and pick one per row. The stacking order matches the
    # argmin index: 0 -> R is min, 1 -> G is min, 2 -> B is min.
    denom = chroma + epsilon
    h_min_r = 60.0 * (b - g) / denom + 180.0
    h_min_g = 60.0 * (r - b) / denom + 300.0
    h_min_b = 60.0 * (g - r) / denom + 60.0
    h = torch.stack((h_min_r, h_min_g, h_min_b), dim=0).gather(
        dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)

    # A pure red (G == B) yields 360 when G is reported as the minimum and
    # 0 when B is; wrap into [0, 360) so the result does not depend on how
    # ``min`` breaks ties. For grays hue is undefined and any channel can
    # be the argmin, so pin it to 0.
    h = h.remainder(360.0).masked_fill(chroma == 0, 0.0)

    # epsilon only guards the denominator: adding it to the numerator as
    # well would give black (max == min == 0) a saturation of 1.
    s = chroma / (max_rgb + epsilon)
    v = max_rgb

    return torch.stack((h, s, v), dim=1)
|
|
|
|
|
|
|
|
def rgb2lab(rgb):
    """Convert a tensor of RGB colors int[0, 255] or float [0, 1] to LAB
    colors.

    Reimplemented from:
    https://gist.github.com/manojpandey/f5ece715132c572c80421febebaf66ae

    :param rgb: tensor whose last dimension holds the R, G, B channels
        (typically shape (N, 3)). It is normalized with ``to_float_rgb``
        (presumably to float [0, 1]); other leading shapes are only
        supported as far as ``to_float_rgb`` supports them.
    :return: tensor of the same shape and dtype as the float RGB, holding
        (L, a, b) in the last dimension, rounded to 4 decimals.
    """
    # Work on a copy: the steps below modify ``rgb`` in place
    rgb = to_float_rgb(rgb.clone())

    # All constant tensors use the input's device and dtype, so float64
    # (or non-default-dtype) inputs do not hit a dtype mismatch in ``@``.
    device, dtype = rgb.device, rgb.dtype

    # Undo the sRGB transfer curve to get linear RGB, scaled to [0, 100]
    mask = rgb > 0.04045
    rgb[mask] = ((rgb[mask] + 0.055) / 1.055) ** 2.4
    rgb[~mask] = rgb[~mask] / 12.92
    rgb *= 100

    # Linear RGB -> XYZ. The matrix is the transpose of the usual sRGB
    # matrix because colors are row vectors here (rgb @ M).
    rgb_to_xyz = torch.tensor([
        [0.4124, 0.2126, 0.0193],
        [0.3576, 0.7152, 0.1192],
        [0.1805, 0.0722, 0.9505]], device=device, dtype=dtype)
    xyz = (rgb @ rgb_to_xyz).round(decimals=4)

    # Normalize by the reference white (D65, 2-degree observer). A 1D
    # tensor broadcasts over the last dimension for any leading shape.
    white = torch.tensor([95.047, 100.0, 108.883], device=device, dtype=dtype)
    xyz /= white

    # CIELAB f(t): cube root above the threshold, linear segment below it
    # (1 / 7.25 == 16 / 116).
    mask = xyz > 0.008856
    xyz[mask] = xyz[mask] ** (1 / 3.)
    xyz[~mask] = 7.787 * xyz[~mask] + 1 / 7.25

    # L = 116 fy - 16, a = 500 (fx - fy), b = 200 (fy - fz); the -16 is
    # applied separately on the L channel below.
    xyz_to_lab = torch.tensor([
        [0, 500, 0],
        [116, -500, 200],
        [0, 0, -200]], device=device, dtype=dtype)
    lab = xyz @ xyz_to_lab
    # Index the channel axis explicitly so non-2D inputs are handled too
    lab[..., 0] -= 16

    return lab.round(decimals=4)
|
|
|