Spaces:
Runtime error
Runtime error
| import torch | |
| from jaxtyping import Float | |
| def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: | |
| switch_val = 0.04045 | |
| return torch.where( | |
| torch.greater(x, switch_val), | |
| ((x.clip(min=switch_val) + 0.055) / 1.055).pow(2.4), | |
| x / 12.92, | |
| ) | |
| def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: | |
| switch_val = 0.0031308 | |
| return torch.where( | |
| torch.greater(x, switch_val), | |
| 1.055 * x.clip(min=switch_val).pow(1.0 / 2.4) - 0.055, | |
| x * 12.92, | |
| ) | |
| def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: | |
| srgb_pixels = torch.reshape(srgb, [-1, 3]) | |
| linear_mask = srgb_pixels <= 0.04045 | |
| exponential_mask = srgb_pixels > 0.04045 | |
| rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + ( | |
| ((srgb_pixels + 0.055) / 1.055) ** 2.4 | |
| ) * exponential_mask | |
| rgb_to_xyz = ( | |
| torch.tensor( | |
| [ | |
| # X Y Z | |
| [0.412453, 0.212671, 0.019334], # R | |
| [0.357580, 0.715160, 0.119193], # G | |
| [0.180423, 0.072169, 0.950227], # B | |
| ] | |
| ) | |
| .to(srgb.dtype) | |
| .to(srgb.device) | |
| ) | |
| xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz) | |
| xyz_normalized_pixels = torch.mul( | |
| xyz_pixels, | |
| torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754]).to(srgb.dtype).to(srgb.device), | |
| ) | |
| epsilon = 6.0 / 29.0 | |
| linear_mask = (xyz_normalized_pixels <= (epsilon**3)).to(srgb.dtype).to(srgb.device) | |
| exponential_mask = ( | |
| (xyz_normalized_pixels > (epsilon**3)).to(srgb.dtype).to(srgb.device) | |
| ) | |
| fxfyfz_pixels = ( | |
| xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0 | |
| ) * linear_mask + ( | |
| (xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0) | |
| ) * exponential_mask | |
| fxfyfz_to_lab = ( | |
| torch.tensor( | |
| [ | |
| # l a b | |
| [0.0, 500.0, 0.0], # fx | |
| [116.0, -500.0, 200.0], # fy | |
| [0.0, 0.0, -200.0], # fz | |
| ] | |
| ) | |
| .to(srgb.dtype) | |
| .to(srgb.device) | |
| ) | |
| lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor( | |
| [-16.0, 0.0, 0.0] | |
| ).to(srgb.dtype).to(srgb.device) | |
| return torch.reshape(lab_pixels, srgb.shape) | |