File size: 2,192 Bytes
26225c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
from src.utils.color import to_float_rgb
__all__ = ['rgb2hsv', 'rgb2lab']
def rgb2hsv(rgb, epsilon=1e-10):
"""Convert a 2D tensor of RGB colors int [0, 255] or float [0, 1] to
HSV format.
Credit: https://www.linuxtut.com/en/20819a90872275811439
"""
assert rgb.ndim == 2
assert rgb.shape[1] == 3
rgb = rgb.clone()
# Convert colors to float in [0, 1]
rgb = to_float_rgb(rgb)
r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2]
max_rgb, argmax_rgb = rgb.max(1)
min_rgb, argmin_rgb = rgb.min(1)
max_min = max_rgb - min_rgb + epsilon
h1 = 60.0 * (g - r) / max_min + 60.0
h2 = 60.0 * (b - g) / max_min + 180.0
h3 = 60.0 * (r - b) / max_min + 300.0
h = torch.stack((h2, h3, h1), dim=0).gather(
dim=0, index=argmin_rgb.unsqueeze(0)).squeeze(0)
s = max_min / (max_rgb + epsilon)
v = max_rgb
return torch.stack((h, s, v), dim=1)
def rgb2lab(rgb):
"""Convert a tensor of RGB colors int[0, 255] or float [0, 1] to LAB
colors.
Reimplemented from:
https://gist.github.com/manojpandey/f5ece715132c572c80421febebaf66ae
"""
rgb = rgb.clone()
device = rgb.device
# Convert colors to float in [0, 1]
rgb = to_float_rgb(rgb)
# Prepare RGB to XYZ
mask = rgb > 0.04045
rgb[mask] = ((rgb[mask] + 0.055) / 1.055) ** 2.4
rgb[~mask] = rgb[~mask] / 12.92
rgb *= 100
# RGB to XYZ conversion
m = torch.tensor([
[0.4124, 0.2126, 0.0193],
[0.3576, 0.7152, 0.1192],
[0.1805, 0.0722, 0.9505]], device=device)
xyz = (rgb @ m).round(decimals=4)
# Observer=2°, Illuminant=D6
# ref_X=95.047, ref_Y=100.000, ref_Z=108.883
scale = torch.tensor([[95.047, 100.0, 108.883]], device=device)
xyz /= scale
# Prepare XYZ for LAB
mask = xyz > 0.008856
xyz[mask] = xyz[mask] ** (1 / 3.)
xyz[~mask] = 7.787 * xyz[~mask] + 1 / 7.25
# XYZ to LAB conversion
lab = torch.zeros_like(xyz)
m = torch.tensor([
[0, 500, 0],
[116, -500, 200],
[0, 0, -200]], device=device, dtype=torch.float)
lab = xyz @ m
lab[:, 0] -= 16
lab = lab.round(decimals=4)
return lab
|