File size: 2,355 Bytes
7349148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from jaxtyping import Float


def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Decode sRGB-encoded values to linear light, element-wise.

    The curve is piecewise:

    * ``x / 12.92`` where ``x <= 0.04045``
    * ``((x + 0.055) / 1.055) ** 2.4`` where ``x > 0.04045``

    Args:
        x: Tensor of sRGB-encoded values. The operation is element-wise, so
            any shape is processed.

    Returns:
        Tensor of linear values, broadcast-compatible with (and for a plain
        tensor input, the same shape as) ``x``.
    """
    threshold = 0.04045
    above_threshold = x > threshold

    # The power branch is evaluated for every element, including those that
    # torch.where will discard. Clamping first guarantees the base of the
    # power is strictly positive, so that branch never produces NaN.
    safe_input = torch.clamp(x, min=threshold)
    power_branch = torch.pow((safe_input + 0.055) / 1.055, 2.4)

    linear_branch = x / 12.92
    return torch.where(above_threshold, power_branch, linear_branch)


def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Encode linear-light values with the sRGB transfer curve, element-wise.

    The curve is piecewise:

    * ``x * 12.92`` where ``x <= 0.0031308``
    * ``1.055 * x ** (1 / 2.4) - 0.055`` where ``x > 0.0031308``

    Args:
        x: Tensor of linear values. The operation is element-wise, so any
            shape is processed.

    Returns:
        Tensor of sRGB-encoded values, the same shape as ``x`` for a plain
        tensor input.
    """
    threshold = 0.0031308
    above_threshold = x > threshold

    # Both branches are computed for all elements; clamping keeps the base of
    # the fractional power positive so the discarded entries cannot be NaN.
    safe_input = torch.clamp(x, min=threshold)
    power_branch = 1.055 * torch.pow(safe_input, 1.0 / 2.4) - 0.055

    linear_branch = x * 12.92
    return torch.where(above_threshold, power_branch, linear_branch)


def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]:
    """Convert sRGB colours to CIE L*a*b*.

    The pipeline is: sRGB decode -> linear RGB -> XYZ (3x3 matrix) ->
    normalise by the reference white -> piecewise cube-root ``f`` ->
    linear map to (L, a, b).

    Both piecewise steps select their branch with ``torch.where`` on a
    clamped base rather than by multiplying with 0/1 masks. With masks, a
    NaN produced in the unselected branch (a fractional power of a negative
    number) survives as ``NaN * 0 = NaN``, so out-of-gamut / negative inputs
    poisoned the result.

    Args:
        srgb: Floating-point tensor whose last dimension holds the three
            sRGB channels (R, G, B). Any number of leading dimensions is
            accepted. Values are nominally in ``[0, 1]``; values outside
            that range are extrapolated along the linear segments instead
            of yielding NaN.

    Returns:
        Tensor with the same shape, dtype and device as ``srgb`` whose last
        dimension holds (L, a, b).

    Raises:
        ValueError: If the last dimension of ``srgb`` is not of size 3.
    """
    if srgb.dim() == 0 or srgb.shape[-1] != 3:
        # Flattening to (-1, 3) would otherwise mix channels of different
        # pixels whenever the element count happens to be divisible by 3.
        raise ValueError(
            "rgb_to_lab expects the last dimension to hold 3 channels, "
            f"got shape {tuple(srgb.shape)}"
        )

    tensor_kwargs = {"dtype": srgb.dtype, "device": srgb.device}
    srgb_pixels = torch.reshape(srgb, [-1, 3])

    # --- sRGB -> linear RGB -------------------------------------------------
    srgb_threshold = 0.04045
    rgb_pixels = torch.where(
        srgb_pixels > srgb_threshold,
        # Clamp so the base of the power is positive even for the elements
        # that torch.where discards (keeps values and gradients finite).
        ((srgb_pixels.clamp(min=srgb_threshold) + 0.055) / 1.055).pow(2.4),
        srgb_pixels / 12.92,
    )

    # --- linear RGB -> XYZ --------------------------------------------------
    rgb_to_xyz = torch.tensor(
        [
            #    X        Y          Z
            [0.412453, 0.212671, 0.019334],  # R
            [0.357580, 0.715160, 0.119193],  # G
            [0.180423, 0.072169, 0.950227],  # B
        ],
        **tensor_kwargs,
    )
    xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz)

    # Normalise by the reference white: (0.950456, 1.0, 1.088754) is what the
    # matrix above maps RGB = (1, 1, 1) to, so white becomes (1, 1, 1).
    inverse_white = torch.tensor(
        [1 / 0.950456, 1.0, 1 / 1.088754], **tensor_kwargs
    )
    xyz_normalized_pixels = xyz_pixels * inverse_white

    # --- XYZ -> f(X), f(Y), f(Z) --------------------------------------------
    epsilon = 6.0 / 29.0
    cube_threshold = epsilon**3
    fxfyfz_pixels = torch.where(
        xyz_normalized_pixels > cube_threshold,
        # Clamping makes the cube root safe for the discarded elements, so no
        # additive fudge term (which slightly biased the result) is needed.
        xyz_normalized_pixels.clamp(min=cube_threshold).pow(1.0 / 3.0),
        xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0,
    )

    # --- f(X), f(Y), f(Z) -> L, a, b ----------------------------------------
    # L = 116 * fy - 16,  a = 500 * (fx - fy),  b = 200 * (fy - fz)
    fxfyfz_to_lab = torch.tensor(
        [
            #  l       a       b
            [0.0, 500.0, 0.0],  # fx
            [116.0, -500.0, 200.0],  # fy
            [0.0, 0.0, -200.0],  # fz
        ],
        **tensor_kwargs,
    )
    lab_offset = torch.tensor([-16.0, 0.0, 0.0], **tensor_kwargs)
    lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + lab_offset

    return torch.reshape(lab_pixels, srgb.shape)