| | import torch |
| | from jaxtyping import Float |
| |
|
| |
|
def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Decode sRGB-encoded values to linear intensities, element-wise.

    Values at or below 0.04045 use the linear segment ``x / 12.92``; values
    above it use the power segment ``((x + 0.055) / 1.055) ** 2.4``.

    Args:
        x: Tensor of sRGB-encoded values. Any shape is accepted because the
            operation is purely element-wise.

    Returns:
        A tensor of the same shape as ``x`` holding the linear values.
    """
    knee = 0.04045
    on_power_segment = x > knee

    linear_segment = x / 12.92

    # The base is clamped to the knee before exponentiation, so entries that
    # `where` discards never raise a value below the knee to the power 2.4
    # (presumably to keep the result and its gradient finite for inputs <= 0).
    power_base = (torch.clamp(x, min=knee) + 0.055) / 1.055
    power_segment = torch.pow(power_base, 2.4)

    return torch.where(on_power_segment, power_segment, linear_segment)
| |
|
| |
|
def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Encode linear intensities with the sRGB transfer curve, element-wise.

    Values at or below 0.0031308 use the linear segment ``x * 12.92``; values
    above it use the power segment ``1.055 * x ** (1 / 2.4) - 0.055``.

    Args:
        x: Tensor of linear values. Any shape is accepted because the
            operation is purely element-wise.

    Returns:
        A tensor of the same shape as ``x`` holding the sRGB-encoded values.
    """
    knee = 0.0031308
    on_power_segment = x > knee

    linear_segment = x * 12.92

    # The base is clamped to the knee before the fractional power, so entries
    # that `where` discards never take a root of zero or a negative number
    # (presumably to keep the result and its gradient finite there).
    rooted = torch.pow(torch.clamp(x, min=knee), 1.0 / 2.4)
    power_segment = 1.055 * rooted - 0.055

    return torch.where(on_power_segment, power_segment, linear_segment)
| |
|
| |
|
def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Convert sRGB-encoded colours to CIE L*a*b*.

    The pipeline is: sRGB decode -> linear RGB -> XYZ (3x3 matrix) ->
    white-point normalisation -> CIE ``f(t)`` non-linearity -> L*a*b*.

    Both piecewise stages select their branch with ``torch.where`` and clamp
    the base of the power in the unselected branch. Blending the branches
    with 0/1 masks instead would let a NaN from a negative power base leak
    through (``NaN * 0`` is NaN), so out-of-gamut negative inputs would give
    NaN colours.

    Args:
        srgb: sRGB-encoded colours. The tensor is flattened to ``(-1, 3)``,
            so the last dimension is expected to hold the three channels.

    Returns:
        A tensor with the same shape as ``srgb`` whose last dimension holds
        ``(L, a, b)``.
    """
    dtype, device = srgb.dtype, srgb.device
    srgb_pixels = torch.reshape(srgb, [-1, 3])

    # --- sRGB -> linear RGB -------------------------------------------------
    srgb_knee = 0.04045
    rgb_pixels = torch.where(
        srgb_pixels > srgb_knee,
        # Clamp so discarded entries never raise a negative base to 2.4.
        ((srgb_pixels.clip(min=srgb_knee) + 0.055) / 1.055) ** 2.4,
        srgb_pixels / 12.92,
    )

    # --- linear RGB -> XYZ --------------------------------------------------
    # Row i holds the (X, Y, Z) contribution of input channel i, so the
    # row-vector product `pixels @ matrix` yields XYZ per pixel.
    rgb_to_xyz = torch.tensor(
        [
            [0.412453, 0.212671, 0.019334],
            [0.357580, 0.715160, 0.119193],
            [0.180423, 0.072169, 0.950227],
        ],
        dtype=dtype,
        device=device,
    )
    xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)

    # --- normalise by the reference white ----------------------------------
    # The divisors are the column sums of the matrix above, i.e. the XYZ of
    # RGB = (1, 1, 1); they match the D65 white point.
    inverse_white = torch.tensor(
        [1 / 0.950456, 1.0, 1 / 1.088754],
        dtype=dtype,
        device=device,
    )
    xyz_normalized_pixels = xyz_pixels * inverse_white

    # --- CIE f(t): cube root above epsilon**3, linear ramp below -----------
    epsilon = 6.0 / 29.0
    cube_knee = epsilon**3
    fxfyfz_pixels = torch.where(
        xyz_normalized_pixels > cube_knee,
        # Clamp so discarded entries never take the cube root of a negative
        # number; the small offset is kept from the original formulation.
        (xyz_normalized_pixels.clip(min=cube_knee) + 0.000001) ** (1.0 / 3.0),
        xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0,
    )

    # --- f(X), f(Y), f(Z) -> L*a*b* ----------------------------------------
    # L = 116 fy - 16, a = 500 (fx - fy), b = 200 (fy - fz).
    fxfyfz_to_lab = torch.tensor(
        [
            [0.0, 500.0, 0.0],
            [116.0, -500.0, 200.0],
            [0.0, 0.0, -200.0],
        ],
        dtype=dtype,
        device=device,
    )
    lab_offset = torch.tensor([-16.0, 0.0, 0.0], dtype=dtype, device=device)
    lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + lab_offset

    return torch.reshape(lab_pixels, srgb.shape)
| |
|