ReSWD / src /utils /color_space.py
mboss's picture
Initial commit
7349148
raw
history blame
2.36 kB
import torch
from jaxtyping import Float
def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
switch_val = 0.04045
return torch.where(
torch.greater(x, switch_val),
((x.clip(min=switch_val) + 0.055) / 1.055).pow(2.4),
x / 12.92,
)
def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
switch_val = 0.0031308
return torch.where(
torch.greater(x, switch_val),
1.055 * x.clip(min=switch_val).pow(1.0 / 2.4) - 0.055,
x * 12.92,
)
def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
srgb_pixels = torch.reshape(srgb, [-1, 3])
linear_mask = srgb_pixels <= 0.04045
exponential_mask = srgb_pixels > 0.04045
rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (
((srgb_pixels + 0.055) / 1.055) ** 2.4
) * exponential_mask
rgb_to_xyz = (
torch.tensor(
[
# X Y Z
[0.412453, 0.212671, 0.019334], # R
[0.357580, 0.715160, 0.119193], # G
[0.180423, 0.072169, 0.950227], # B
]
)
.to(srgb.dtype)
.to(srgb.device)
)
xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)
xyz_normalized_pixels = torch.mul(
xyz_pixels,
torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754]).to(srgb.dtype).to(srgb.device),
)
epsilon = 6.0 / 29.0
linear_mask = (xyz_normalized_pixels <= (epsilon**3)).to(srgb.dtype).to(srgb.device)
exponential_mask = (
(xyz_normalized_pixels > (epsilon**3)).to(srgb.dtype).to(srgb.device)
)
fxfyfz_pixels = (
xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0
) * linear_mask + (
(xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0)
) * exponential_mask
fxfyfz_to_lab = (
torch.tensor(
[
# l a b
[0.0, 500.0, 0.0], # fx
[116.0, -500.0, 200.0], # fy
[0.0, 0.0, -200.0], # fz
]
)
.to(srgb.dtype)
.to(srgb.device)
)
lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor(
[-16.0, 0.0, 0.0]
).to(srgb.dtype).to(srgb.device)
return torch.reshape(lab_pixels, srgb.shape)