File size: 8,612 Bytes
27fdf85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""

OKLab Color Space Utilities



Perceptually uniform color space for semantic loss computation.

OKLab ensures that equal distances in the color space correspond to

equal perceived differences — critical for meaningful color-based encoding.



Key functions:

- srgb_to_oklab / oklab_to_srgb: Color space conversions

- rotate_ab: Rotate hue in a-b plane (for domain/idiom shifts)

- set_chroma: Set chroma magnitude (for purity encoding)

- OKLabMSELoss: Perceptually uniform loss function

- hsl_to_oklab_batch: Batch conversion for training

"""

import torch
import torch.nn as nn
import math
from typing import Tuple


def clamp(x: float, lo: float, hi: float) -> float:
    """Return x limited to the closed interval [lo, hi]."""
    # Cap from above first, then floor from below (same order as max(lo, min(hi, x))).
    capped = x if x < hi else hi
    return capped if capped > lo else lo


# ── sRGB ↔ Linear RGB ──

def srgb_to_linear(c: float) -> float:
    """Decode one gamma-encoded sRGB channel value to linear light."""
    # Piecewise sRGB transfer: straight line near black, 2.4 power curve above.
    return c / 12.92 if c <= 0.04045 else ((c + 0.055) / 1.055) ** 2.4


def linear_to_srgb(c: float) -> float:
    """Gamma-encode one linear-light channel value with the sRGB transfer curve."""
    in_linear_segment = c <= 0.0031308
    return c * 12.92 if in_linear_segment else 1.055 * (c ** (1.0 / 2.4)) - 0.055


# ── sRGB ↔ OKLab ──

def srgb_to_oklab(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """
    Convert a gamma-encoded sRGB triple (nominally in [0, 1]) to OKLab.

    Returns:
        (L, a, b) OKLab coordinates.
    """
    def cbrt(v: float) -> float:
        # Real, sign-preserving cube root: a negative float ** (1/3) would be complex.
        return v ** (1.0 / 3.0) if v >= 0 else -((-v) ** (1.0 / 3.0))

    lin_r, lin_g, lin_b = (srgb_to_linear(ch) for ch in (r, g, b))

    # Linear sRGB -> intermediate LMS-style responses, each cube-root compressed.
    lc = cbrt(0.4122214708 * lin_r + 0.5363325363 * lin_g + 0.0514459929 * lin_b)
    mc = cbrt(0.2119034982 * lin_r + 0.6806995451 * lin_g + 0.1073969566 * lin_b)
    sc = cbrt(0.0883024619 * lin_r + 0.2817188376 * lin_g + 0.6299787005 * lin_b)

    # Compressed responses -> OKLab (L, a, b).
    return (
        0.2104542553 * lc + 0.7936177850 * mc - 0.0040720468 * sc,
        1.9779984951 * lc - 2.4285922050 * mc + 0.4505937099 * sc,
        0.0259040371 * lc + 0.7827717662 * mc - 0.8086757660 * sc,
    )


def oklab_to_srgb(L: float, a: float, b_ok: float) -> Tuple[float, float, float]:
    """
    Convert OKLab coordinates to an sRGB triple in [0, 1].

    Out-of-gamut colors are clipped per channel (once in linear light and
    again after gamma encoding) rather than gamut-mapped.
    """
    # OKLab -> cube-root-compressed intermediate responses.
    compressed = (
        L + 0.3963377774 * a + 0.2158037573 * b_ok,
        L - 0.1055613458 * a - 0.0638541728 * b_ok,
        L - 0.0894841775 * a - 1.2914855480 * b_ok,
    )
    # Undo the compression by cubing.
    l_, m_, s_ = (v * v * v for v in compressed)

    # Intermediate responses -> linear sRGB.
    linear = (
        +4.0767416621 * l_ - 3.3077115913 * m_ + 0.2309699292 * s_,
        -1.2684380046 * l_ + 2.6097574011 * m_ - 0.3413193965 * s_,
        -0.0041960863 * l_ - 0.7034186147 * m_ + 1.7076147010 * s_,
    )

    r, g, b = (clamp(linear_to_srgb(clamp(ch, 0, 1)), 0, 1) for ch in linear)
    return (r, g, b)


# ── HSL ↔ RGB ──

def hsl_to_rgb(h_deg: float, s_pct: float, l_pct: float) -> Tuple[float, float, float]:
    """
    Convert HSL (degrees, percent, percent) to RGB [0,1].

    Args:
        h_deg: Hue in degrees. Any real value is accepted and wrapped modulo
            360, so e.g. 730 and -350 both mean 10.
        s_pct: Saturation in percent, [0, 100].
        l_pct: Lightness in percent, [0, 100].

    Returns:
        (r, g, b), each in [0, 1] when S and L are within range.
    """
    h = h_deg / 360.0
    s = s_pct / 100.0
    l = l_pct / 100.0

    if s == 0:
        # Achromatic: every channel equals the lightness.
        return (l, l, l)

    def hue_to_rgb(p, q, t):
        # Wrap fully into [0, 1) so any hue angle works; a single +/-1 shift
        # is not enough once |hue| reaches 360 degrees or more. This matches
        # the `t % 1.0` used by the vectorized hsl_to_oklab_batch.
        t %= 1.0
        if t < 1/6: return p + (q - p) * 6 * t
        if t < 1/2: return q
        if t < 2/3: return p + (q - p) * (2/3 - t) * 6
        return p

    q = l * (1 + s) if l < 0.5 else l + s - l * s
    p = 2 * l - q

    r = hue_to_rgb(p, q, h + 1/3)
    g = hue_to_rgb(p, q, h)
    b = hue_to_rgb(p, q, h - 1/3)

    return (r, g, b)


def rgb_to_hsl(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """
    Convert RGB in [0, 1] to HSL.

    Returns:
        (hue in degrees, saturation in percent, lightness in percent).
        Grays (all channels equal) report hue 0 and saturation 0.
    """
    hi = max(r, g, b)
    lo = min(r, g, b)
    lightness = (hi + lo) / 2.0

    if hi == lo:
        return (0.0, 0.0, lightness * 100.0)

    span = hi - lo
    denom = (2.0 - hi - lo) if lightness > 0.5 else (hi + lo)
    saturation = span / denom

    # Hue sector (0..6) depends on which channel is the maximum.
    if hi == r:
        sector = (g - b) / span + (6 if g < b else 0)
    elif hi == g:
        sector = (b - r) / span + 2
    else:
        sector = (r - g) / span + 4

    hue = sector / 6.0
    return (hue * 360.0, saturation * 100.0, lightness * 100.0)


# ── OKLab Operations ──

def rotate_ab(a: float, b: float, degrees: float) -> Tuple[float, float]:
    """
    Rotate the (a, b) vector in the OKLab a-b plane by `degrees`.

    Uses the standard 2-D rotation matrix, so positive angles turn +a toward +b.
    """
    theta = math.radians(degrees)
    c, s = math.cos(theta), math.sin(theta)
    rotated_a = a * c - b * s
    rotated_b = a * s + b * c
    return (rotated_a, rotated_b)


def set_chroma(a: float, b: float, target_c: float) -> Tuple[float, float]:
    """
    Rescale the (a, b) vector so its length (chroma) equals target_c.

    The direction is kept for target_c >= 0. If the input is numerically
    achromatic (length < 1e-10), there is no direction to keep, so the result
    lies on the +a axis: (target_c, 0.0).
    """
    length = math.sqrt(a * a + b * b)
    if length < 1e-10:
        return (target_c, 0.0)  # no usable direction; fall back to +a
    factor = target_c / length
    return (a * factor, b * factor)


def get_chroma(a: float, b: float) -> float:
    """Return the chroma: the Euclidean length of the (a, b) vector."""
    squared_length = a * a + b * b
    return math.sqrt(squared_length)


def compute_delta_e_oklab(
    L1: float, a1: float, b1: float,
    L2: float, a2: float, b2: float,
) -> float:
    """
    Return ΔE in OKLab: the Euclidean distance between two (L, a, b) colors.

    In a perceptually uniform space this distance approximates how different
    the two colors look.
    """
    dL = L1 - L2
    da = a1 - a2
    db = b1 - b2
    return math.sqrt(dL ** 2 + da ** 2 + db ** 2)


# ── Batch Operations (PyTorch) ──

def _srgb_to_linear_t(c: torch.Tensor) -> torch.Tensor:
    """Tensor sRGB gamma decode whose gradient stays finite for negative inputs."""
    # torch.where back-propagates through BOTH branches. For c < -0.055 the
    # unclamped power branch is NaN, and 0 * NaN poisons the gradient even
    # though the linear branch is the one selected. The clamp is the identity
    # wherever the power branch is actually chosen (c > 0.04045).
    curve = ((c.clamp(min=0.04045) + 0.055) / 1.055) ** 2.4
    return torch.where(c <= 0.04045, c / 12.92, curve)


def _signed_cbrt(x: torch.Tensor) -> torch.Tensor:
    """Sign-preserving real cube root with a finite gradient at x == 0."""
    # d/dx |x|**(1/3) is infinite at 0, so autograd would compute 0 * inf = NaN.
    # Flooring |x| at the smallest normal float only alters subnormal inputs;
    # at x == 0 the output is still exactly 0 because sign(0) == 0.
    tiny = torch.finfo(x.dtype).tiny
    return torch.sign(x) * torch.abs(x).clamp(min=tiny).pow(1 / 3)


def hsl_to_oklab_batch(hsl: torch.Tensor) -> torch.Tensor:
    """
    Batch convert normalized HSL to OKLab (differentiable).

    Args:
        hsl: (..., 3) tensor with H, S, L each normalized to [0, 1]
             (H = degrees / 360, S and L = percent / 100). Hue wraps
             modulo 1, so 0.0 and 1.0 denote the same hue.

    Returns:
        (..., 3) tensor with OKLab L, a, b along the last dimension.

    Gradients stay finite for pure black (L = 0) and for out-of-range
    predictions that produce negative RGB. A naive where/pow implementation
    yields NaN in both cases.
    """
    h = hsl[..., 0]
    s = hsl[..., 1]
    l = hsl[..., 2]

    # HSL -> sRGB, a vectorized mirror of hsl_to_rgb.
    q = torch.where(l < 0.5, l * (1 + s), l + s - l * s)
    p = 2 * l - q

    def hue2rgb(t: torch.Tensor) -> torch.Tensor:
        t = t % 1.0  # hue is circular
        return torch.where(t < 1/6, p + (q - p) * 6 * t,
               torch.where(t < 1/2, q,
               torch.where(t < 2/3, p + (q - p) * (2/3 - t) * 6, p)))

    r = hue2rgb(h + 1/3)
    g = hue2rgb(h)
    b = hue2rgb(h - 1/3)

    # Snap near-zero saturation (< 0.001 percent, i.e. 1e-5 normalized) to gray.
    # NOTE(review): at s == 0 the formulas above already give (l, l, l), so this
    # only matters for 0 < s < 1e-5, where it also blocks the gradient w.r.t. S.
    # Kept for backward compatibility.
    achromatic = s < 1e-5
    r = torch.where(achromatic, l, r)
    g = torch.where(achromatic, l, g)
    b = torch.where(achromatic, l, b)

    # sRGB -> linear RGB
    r_lin = _srgb_to_linear_t(r)
    g_lin = _srgb_to_linear_t(g)
    b_lin = _srgb_to_linear_t(b)

    # Linear RGB -> intermediate responses (same matrix as srgb_to_oklab)
    l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin
    m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin
    s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin

    l_c = _signed_cbrt(l_)
    m_c = _signed_cbrt(m_)
    s_c = _signed_cbrt(s_)

    L_ok = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c
    a_ok = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c
    b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c

    return torch.stack([L_ok, a_ok, b_ok], dim=-1)


def denormalize_hsl(hsl_norm: torch.Tensor) -> torch.Tensor:
    """
    Scale normalized HSL back to (degrees, percent, percent).

    H: [0, 1] -> [0, 360]; S, L: [0, 1] -> [0, 100]. The last dimension
    indexes H, S, L. The input is left untouched; a clone is scaled in place.
    """
    out = hsl_norm.clone()
    for channel, factor in enumerate((360.0, 100.0, 100.0)):
        out[..., channel] *= factor
    return out


class OKLabMSELoss(nn.Module):
    """
    Mean-squared-error loss computed in OKLab space.

    Predicted and target colors arrive as normalized HSL, are converted to
    OKLab with hsl_to_oklab_batch, and are compared element-wise there.
    Because OKLab expresses hue through Cartesian a-b coordinates rather than
    an angle, hues on either side of the wrap-around (359° vs. 1°) land on
    nearby points, so the loss treats them as similar.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        pred_hsl: torch.Tensor,    # (B, 3) predicted HSL in [0,1]
        target_hsl: torch.Tensor,  # (B, 3) target HSL in [0,1]
    ) -> torch.Tensor:
        """Return mse_loss (mean reduction) between the OKLab conversions of pred and target."""
        pred_lab, target_lab = (
            hsl_to_oklab_batch(colors) for colors in (pred_hsl, target_hsl)
        )
        return nn.functional.mse_loss(pred_lab, target_lab)