File size: 18,896 Bytes
dc9bb20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
from collections.abc import Sequence
from functools import wraps
from math import prod, sqrt
from typing import Optional

import torch

from ..._ops import register_kernel
from ..utils import CODE


def _try_torch_compile(func=None, **compile_kwargs):
    """
    Wrapper around torch.compile that falls back to the original function if compilation fails.
    """

    def decorator(fn):
        try:
            compiled_fn = torch.compile(fn, **compile_kwargs)

            @wraps(fn)
            def wrapper(*args, **kwargs):
                try:
                    return compiled_fn(*args, **kwargs)
                except Exception:
                    return fn(*args, **kwargs)

            return wrapper
        except Exception:
            return fn

    if func is None:
        return decorator
    else:
        return decorator(func)


@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Dequantize the int32 result of an int8 matmul.

    Args:
        A: int32 accumulator tensor; it is viewed as 2D ``(-1, A.shape[-1])``.
        row_stats: float32 per-row scales (flattened, applied along rows).
        col_stats: float32 per-column scales (flattened, applied along columns).
        dtype: Output dtype; ``torch.float16`` when falsy.
        bias: Optional tensor added to the scaled result before the final cast.

    Returns:
        The 2D dequantized tensor in ``dtype``.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    A_calc = A.view(-1, A.shape[-1])
    row_stats = row_stats.reshape(-1).unsqueeze(-1)
    col_stats = col_stats.reshape(-1).unsqueeze(0)

    # Each int8 operand is assumed to have been scaled by 127 / absmax (as the
    # vectorwise quantizer in this file does), so the product carries a factor
    # of 127 * 127 that has to be divided out. The previous literal
    # (6.200124e-05) was a mistyped 1 / 16129 ~= 6.2000124e-05.
    dequant_scale = 1.0 / (127.0 * 127.0)

    out = A_calc * (row_stats * col_stats) * dequant_scale
    if bias is not None:
        out += bias

    return out.to(dtype or torch.float16)


@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Int8 scaled matmul with an optional higher-precision path for outlier columns.

    The int8 product of ``CA`` and ``CB`` is dequantized (with ``bias``) into
    ``A.dtype``. When ``outlier_cols`` is non-empty, the matching columns of
    ``A`` are multiplied with the dequantized matching columns of ``CB`` and
    added on top.

    Returns:
        ``(output, subA)`` where ``subA`` holds the outlier columns of ``A``,
        or an empty tensor when there are none.
    """
    has_outliers = outlier_cols is not None and outlier_cols.numel() > 0

    if has_outliers:
        # Inputs for the outlier columns, kept in their original precision.
        subA = A[:, outlier_cols].contiguous()

        # Matching weight columns, dequantized and transposed for addmm.
        outlier_weights = CB[:, outlier_cols].contiguous()
        dequantized = torch.ops.bitsandbytes.int8_vectorwise_dequant.default(outlier_weights, SCB)
        subB = dequantized.to(A.dtype).t()

        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
    else:
        # torch.compile needs a tensor here even when there are no outliers.
        subA = torch.empty(0, device=A.device, dtype=A.dtype)
        subB = None

    # Int8 matmul + dequant + bias.
    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

    if subB is None:
        return output, subA

    # Fold the outlier contribution back into the result.
    return output.addmm(subA, subB), subA


@register_kernel("bitsandbytes::int8_scaled_mm", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """
    Int8 matmul followed by dequantization.

    Chains the ``int8_linear_matmul`` op with the ``int8_mm_dequant`` op,
    forwarding ``bias`` and defaulting the output dtype to ``torch.float16``.
    """
    accumulated = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
    target_dtype = dtype or torch.float16

    dequantize = torch.ops.bitsandbytes.int8_mm_dequant.default
    return dequantize(accumulated, row_stats, col_stats, dtype=target_dtype, bias=bias)


@register_kernel("bitsandbytes::int8_linear_matmul", "default")
def _(A: torch.Tensor, B: torch.Tensor):
    """Return the int32 product of ``A`` and ``B`` transposed, as a new tensor."""
    product = _int8_linear_matmul_impl(A, B, None)
    return product


@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    """Write the int32 product of ``A`` and ``B`` transposed into ``out``; returns nothing."""
    out_is_int32 = out.dtype == torch.int32
    torch._check(out_is_int32)
    _int8_linear_matmul_impl(A, B, out=out)


def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
    # Naive implementation: perform matmul in fp32
    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
    if out is not None:
        result = out.copy_(result)
    return result


@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
def _(A: torch.Tensor, threshold=0.0):
    """
    Row-wise absmax quantization of ``A`` to int8 with optional outlier handling.

    With ``threshold > 0``, entries whose magnitude is at least ``threshold``
    are treated as outliers: they are excluded from the row statistics and the
    columns containing them are reported (and zeroed in the quantized output
    when there is more than one row). ``A`` is modified in place while the
    statistics are computed and restored before returning.

    Returns:
        ``(out_row, row_stats, outlier_cols)``: the int8 tensor, the float32
        per-row absmax, and the outlier column indices (``None`` when
        ``threshold`` is not positive).
    """
    n_rows = prod(A.shape[:-1])
    outlier_cols = None
    saved_outliers = None

    if threshold > 0.0:
        outlier_mask = A.abs() >= threshold

        if outlier_mask.any():
            # Record which columns hold outliers, back the outlier values up,
            # and blank them so they do not influence the row absmax.
            outlier_cols = torch.argwhere(outlier_mask.any(dim=0)).view(-1)
            saved_outliers = A[outlier_mask].clone()
            A[outlier_mask] = 0
        else:
            # Needed for torch.compile support.
            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

    # Per-row absmax.
    row_stats = A.abs().max(dim=1).values.float()

    # Scale each row so its absmax maps to 127, then round to int8.
    scale = 127.0 / row_stats.unsqueeze(-1)
    out_row = torch.round(A * scale).to(torch.int8)

    # Clear the outlier columns across every row of the quantized output.
    if outlier_cols is not None and n_rows > 1:
        out_row[:, outlier_cols] = 0

    # Put the original outlier values back into the caller's tensor.
    if saved_outliers is not None:
        A[outlier_mask] = saved_outliers

    return out_row, row_stats, outlier_cols


@register_kernel("bitsandbytes::quantize_blockwise", "default")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Blockwise quantization of ``A`` against the lookup table ``code``.

    ``A`` is flattened and split into blocks of ``blocksize`` elements (the
    last block may be shorter). Each block is divided by its absmax, clamped
    to [-1, 1], and every value is replaced by the index of the nearest entry
    in ``code``.

    Returns:
        ``(out, absmax)``: uint8 indices shaped like ``A`` and the float32
        absmax of every block.
    """
    torch._check_is_size(blocksize)

    total = A.numel()
    full_blocks = total // blocksize
    tail = total % blocksize
    n_blocks = full_blocks + 1 if tail > 0 else full_blocks

    absmax = torch.zeros((n_blocks,), device=A.device, dtype=torch.float32)
    flat = A.reshape(total)

    # Complete blocks: one row per block.
    body = flat[: total - tail].reshape(full_blocks, blocksize)
    absmax[:full_blocks] = torch.abs(body).max(dim=-1).values
    scaled_A = torch.clamp(body * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

    # Trailing partial block, if any.
    if tail > 0:
        leftover = flat[total - tail :]
        absmax[-1] = torch.abs(leftover).max()
        scaled_leftover = torch.clamp(leftover * (1 / absmax[-1]), -1, 1)
        scaled_A = torch.cat([scaled_A, scaled_leftover], dim=0)

    # Nearest code entry for every scaled value.
    distances = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
    nearest = torch.argmin(distances, dim=-1)
    out = nearest.to(torch.uint8).to(scaled_A.device).reshape(A.shape)

    return out, absmax


@register_kernel("bitsandbytes::dequantize_blockwise", "default")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Blockwise dequantization: look up ``code[A]`` and rescale each block by its absmax.

    The looked-up values are zero-padded to a whole number of blocks so they
    can be scaled as a ``(blocks, blocksize)`` matrix; the padding is dropped
    again before the result is reshaped to ``A.shape`` in ``dtype``.
    """
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    indices = A.reshape(-1).int()
    values = code[indices]

    count = values.shape[-1]
    tail = count % blocksize
    if tail != 0:
        # Fill the last block up so the view below is valid.
        values = torch.nn.functional.pad(values, (0, blocksize - tail), mode="constant", value=0)

    scaled = values.view(-1, blocksize) * absmax.view(-1, 1)
    flat = scaled.to(dtype).reshape(-1)

    # Keep only the original elements (drops any padding added above).
    return flat[:count].reshape(A.shape)


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Blockwise 4-bit quantization of ``A`` using the ``quant_type`` lookup table.

    ``A`` is flattened, each block of ``blocksize`` elements (the last one may
    be shorter) is divided by its absmax and clamped to [-1, 1], and every
    value is replaced by the index of the nearest table entry. Two indices are
    packed per byte, the first of each pair in the high nibble. When ``A`` has
    an odd number of elements, a zero index is appended so the final value
    still gets a byte of its own.

    Args:
        A: bfloat16/float16/float32 tensor to quantize.
        blocksize: Number of elements sharing one absmax.
        quant_type: ``"nf4"`` or ``"fp4"``.
        quant_storage: dtype used to reinterpret the packed bytes.

    Returns:
        ``(packed, absmax)``: a column tensor of packed values (``uint8`` or
        reinterpreted as ``quant_storage``) and the float32 per-block absmax.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()
    full_blocks = n // blocksize
    rem = n % blocksize
    blocks = full_blocks + 1 if rem else full_blocks
    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
    A_flattened = A.reshape(n)

    # Scale full blocks of the tensor to [-1, 1]
    A_full_blocks = A_flattened[: n - rem].reshape(full_blocks, blocksize)
    absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0]
    scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

    # Scale any partial block
    if rem:
        A_rem = A_flattened[-rem:]
        absmax[-1] = torch.abs(A_rem).max()
        scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1)
        scaled = torch.cat([scaled, scaled_rem], dim=0)

    # Quantize with the lookup table
    code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)

    # With an odd element count the even/odd slices below would have different
    # lengths (raising, or silently broadcasting for very small inputs), so
    # pad with one zero index to complete the last byte.
    if quantized.shape[0] % 2 != 0:
        quantized = torch.cat([quantized, quantized.new_zeros((1, 1))], dim=0)

    # Pack two quantized values per byte
    packed = quantized[::2] << 4 | quantized[1::2]

    if quant_storage != torch.uint8:
        # squeeze(-1) rather than squeeze(): a single packed byte must stay 1-D
        # so that the reinterpreting view and unsqueeze(1) remain valid.
        packed = packed.squeeze(-1).view(quant_storage).unsqueeze(1)

    return packed, absmax.float()


def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Unpack 4-bit codes from ``A`` and rescale them blockwise with ``absmax``.

    Every byte of ``A`` (reinterpreted as uint8 if stored in another dtype)
    yields two table indices, high nibble first. The indices are mapped
    through the ``quant_type`` lookup table, each block of ``blocksize``
    values is multiplied by its absmax, and the result is reshaped to
    ``(-1, *shape[1:])`` in ``dtype``.

    When the packed data decodes to exactly one value more than ``shape``
    holds (an odd element count padded to a whole byte), the surplus value is
    dropped before scaling.
    """
    # Enable non uint8 dtype
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    A = A.reshape(-1)

    # Split every byte into its two 4-bit indices (high nibble comes first).
    indices = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
    indices[1::2] = A & 0xF
    indices[::2] = A >> 4

    # code is fp32, cast to dtype to avoid the mismatch issue
    code = CODE[quant_type].to(dtype).to(A.device)
    out_dq = code[indices]

    # The element count used to be read back from out_dq itself, which made
    # the "one extra element" trim unreachable. Compare against the requested
    # shape instead so a padding nibble is actually removed.
    n = out_dq.numel()
    if n == prod(shape) + 1:
        n -= 1
        out_dq = torch.narrow(out_dq, 0, 0, n)

    # Apply scales
    full_blocks = n // blocksize
    rem = n % blocksize

    if rem:
        split = n - rem
        out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
        # Complete blocks are scaled row-wise; the trailing partial block uses
        # the last absmax entry.
        out[:split] = (out_dq[:split].view(-1, blocksize) * absmax[:full_blocks].view(-1, 1)).reshape(-1)
        out[split:] = out_dq[split:] * absmax[-1]
    else:
        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

    return out.reshape(-1, *shape[1:]).to(dtype)


@register_kernel("bitsandbytes::dequantize_4bit", "default")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Validate the arguments and dequantize packed 4-bit data.

    ``blocksize`` must be a valid size, ``quant_type`` must be ``"nf4"`` or
    ``"fp4"``, and ``dtype`` must be bfloat16, float16 or float32. The actual
    work is delegated to ``_dequantize_4bit_impl``.
    """
    supported_dtypes = (torch.bfloat16, torch.float16, torch.float32)
    known_quant_type = quant_type in ("nf4", "fp4")
    supported_dtype = dtype in supported_dtypes

    torch._check_is_size(blocksize)
    torch._check(known_quant_type, lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        supported_dtype,
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    return _dequantize_4bit_impl(
        A=A,
        absmax=absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        shape=shape,
        dtype=dtype,
    )


@register_kernel("bitsandbytes::gemv_4bit", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """
    Multiply ``A`` with a 4-bit quantized weight ``B``.

    ``B`` is first dequantized to ``A.dtype`` with the ``dequantize_4bit`` op
    and then applied through ``torch.nn.functional.linear`` without bias.
    """
    # The quantization type is inferred from the table: a positive second
    # entry is taken to mean fp4, anything else nf4.
    # NOTE(review): relies on the sign of code[1] differing between the two
    # tables - confirm against the table definitions.
    if code[1] > 0:
        quant_type = "fp4"
    else:
        quant_type = "nf4"

    dequantize = torch.ops.bitsandbytes.dequantize_4bit.default
    weight = dequantize(B, absmax, blocksize, quant_type, shapeB, A.dtype)

    return torch.nn.functional.linear(A, weight, bias=None)


# Integer ids of the supported optimizers, numbered consecutively from 0.
(
    MOMENTUM,
    RMSPROP,
    ADAGRAD,
    ADAM,
    # LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels
    LION,
    ADEMAMIX,
) = range(6)

# Lookup from the optimizer name to its integer id.
name2optimizer_id = dict(
    momentum=MOMENTUM,
    rmsprop=RMSPROP,
    adagrad=ADAGRAD,
    adam=ADAM,
    lion=LION,
    ademamix=ADEMAMIX,
)


@_try_torch_compile
def _optimizer_precondition_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: torch.Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """
    Compute the update-norm contribution for one step and add it to ``unorm_vec``.

    The would-be optimizer state for ``optimizer_id`` is evaluated from the
    scaled gradient without modifying ``state1``/``state2``; the per-element
    norm terms are summed and accumulated in place into ``unorm_vec``.
    ``p``, ``weight_decay`` and ``lr`` are accepted but not used here.
    """
    scaled_grad = gnorm_scale * g

    if optimizer_id == 0:  # MOMENTUM
        # The very first step uses the gradient itself as the momentum.
        momentum = scaled_grad if step == 1 else state1 * beta1 + scaled_grad
        update_norm = momentum * momentum

    elif optimizer_id == 1:  # RMSPROP
        sq_avg = state1 * beta1 + (1.0 - beta1) * scaled_grad * scaled_grad
        step_vals = scaled_grad / (torch.sqrt(sq_avg) + eps)
        update_norm = step_vals * step_vals

    elif optimizer_id == 2:  # ADAGRAD
        sq_sum = state1 + scaled_grad * scaled_grad
        step_vals = scaled_grad / (torch.sqrt(sq_sum) + eps)
        update_norm = step_vals * step_vals

    elif optimizer_id == 3:  # ADAM
        bias_fix1 = 1.0 / (1.0 - beta1**step)
        bias_fix2 = 1.0 / (1.0 - beta2**step)

        exp_avg = state1 * beta1 + (1.0 - beta1) * scaled_grad
        exp_avg_sq = state2 * beta2 + (1.0 - beta2) * scaled_grad * scaled_grad

        exp_avg = exp_avg * bias_fix1
        exp_avg_sq = exp_avg_sq * bias_fix2

        step_vals = exp_avg / (torch.sqrt(exp_avg_sq) + eps)
        update_norm = step_vals * step_vals

    elif optimizer_id == 4:  # LION
        # The interpolated state is summed as-is (not squared).
        update_norm = state1 * beta2 + (1.0 - beta2) * scaled_grad

    elif optimizer_id == 5:  # ADEMAMIX
        # The current state is summed unchanged.
        update_norm = state1

    unorm_vec.add_(torch.sum(update_norm))


@_try_torch_compile
def _optimizer_update_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """
    Apply one optimizer step in place for the optimizer selected by ``optimizer_id``.

    Parameter and gradient are promoted to float32, the update is computed,
    the optimizer state tensors are overwritten with their new values and the
    new parameter values are copied back into ``p``.

    When ``max_unorm > 0`` the update is scaled down if ``sqrt(unorm_vec)``
    exceeds ``max_unorm * param_norm`` (plus ``eps`` for the single-state
    optimizers). The ADAGRAD and ADEMAMIX branches do not apply that scale.
    """
    # MOMENTUM, RMSPROP, ADAGRAD and LION keep a single state tensor.
    single_state = optimizer_id in (0, 1, 2, 4)

    param = p.float()
    grad = (gnorm_scale * g).float()
    if single_state and weight_decay > 0.0:
        # Weight decay is folded into the gradient for these optimizers.
        grad = grad + param * weight_decay

    update_scale = 1.0
    if max_unorm > 0.0:
        current_unorm = torch.sqrt(unorm_vec)
        if single_state:
            limit = max_unorm * param_norm + eps
        else:
            limit = max_unorm * param_norm
        if current_unorm > limit:
            update_scale = limit / current_unorm

    if optimizer_id == 0:  # MOMENTUM
        # The very first step uses the gradient itself as the momentum.
        momentum = grad if step == 1 else state1 * beta1 + grad

        param = param + update_scale * (-lr * momentum)

        state1.copy_(momentum)

    elif optimizer_id == 1:  # RMSPROP
        sq_avg = state1 * beta1 + (1.0 - beta1) * grad * grad

        param = param - update_scale * lr * grad / (torch.sqrt(sq_avg) + eps)

        state1.copy_(sq_avg)

    elif optimizer_id == 2:  # ADAGRAD
        sq_sum = state1 + grad * grad

        param = param - lr * grad / (torch.sqrt(sq_sum) + eps)

        state1.copy_(sq_sum)

    elif optimizer_id == 3:  # ADAM
        exp_avg = state1 * beta1 + (1.0 - beta1) * grad
        exp_avg_sq = state2 * beta2 + (1.0 - beta2) * grad * grad

        bias1 = 1.0 - beta1**step
        bias2_sqrt = sqrt(1.0 - beta2**step)
        step_size = -lr * bias2_sqrt / bias1

        if weight_decay > 0.0:
            # Weight decay is applied directly to the parameter here.
            param = param * (1.0 - lr * weight_decay)

        param = param + update_scale * step_size * (exp_avg / (torch.sqrt(exp_avg_sq) + eps * bias2_sqrt))

        state1.copy_(exp_avg)
        state2.copy_(exp_avg_sq)

    elif optimizer_id == 4:  # LION
        # The step direction uses beta1, the stored state uses beta2.
        direction = torch.sign(state1 * beta1 + (1.0 - beta1) * grad)
        param = param - update_scale * lr * direction

        new_state = state1 * beta2 + (1.0 - beta2) * grad
        state1.copy_(new_state)

    elif optimizer_id == 5:  # ADEMAMIX
        # state1 stacks two first-moment averages; state2 is the second moment.
        fast_ema = state1[0] * beta1 + (1.0 - beta1) * grad
        slow_ema = state1[1] * beta3 + (1.0 - beta3) * grad
        second_moment = state2 * beta2 + (1.0 - beta2) * grad * grad

        bias1 = 1.0 - beta1**step
        bias2_sqrt = sqrt(1.0 - beta2**step)

        if weight_decay > 0.0:
            param = param * (1.0 - lr * weight_decay)

        numerator = (fast_ema / bias1) + (alpha * slow_ema)
        denominator = (torch.sqrt(second_moment) / bias2_sqrt) + eps
        param = param - lr * (numerator / denominator)

        state1[0].copy_(fast_ema)
        state1[1].copy_(slow_ema)
        state2.copy_(second_moment)

    p.copy_(param)


@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """
    32-bit optimizer step implemented in PyTorch on top of torch.compile'd helpers.

    When ``max_unorm > 0`` the update norm in ``unorm_vec`` is reset and
    recomputed. For every optimizer except ``"lion"`` this happens before the
    parameter update; for ``"lion"`` the parameter update runs first and the
    norm is recomputed afterwards.

    Raises:
        NotImplementedError: if ``skip_zeros`` is set.
        KeyError: if ``optimizer_name`` is not a known optimizer.
    """
    if skip_zeros:
        raise NotImplementedError("skip_zeros is not supported yet")

    optimizer_id = name2optimizer_id[optimizer_name]
    update_first = optimizer_name == "lion"

    update_args = (
        g,
        p,
        state1,
        state2,
        unorm_vec,
        max_unorm,
        param_norm,
        beta1,
        beta2,
        beta3,
        alpha,
        eps,
        weight_decay,
        step,
        lr,
        gnorm_scale,
        optimizer_id,
    )
    precondition_args = (
        g,
        p,
        state1,
        state2,
        unorm_vec,
        beta1,
        beta2,
        eps,
        weight_decay,
        step,
        lr,
        gnorm_scale,
        optimizer_id,
    )

    if update_first:
        _optimizer_update_32bit(*update_args)

    if max_unorm > 0.0:
        # Start the accumulation from zero before recomputing the norm.
        unorm_vec.zero_()
        _optimizer_precondition_32bit(*precondition_args)

    if not update_first:
        _optimizer_update_32bit(*update_args)