from collections.abc import Sequence
import ctypes as ct
from math import prod
from typing import Optional

import torch

from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ...cextension import ROCM_WARP_SIZE_64, lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
def _(A: torch.Tensor, B: torch.Tensor):
    out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
    return _int8_linear_matmul_impl(A, B, out)


@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    _int8_linear_matmul_impl(A, B, out)


def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    A, B = B, A

    shapeA = A.shape
    shapeB = B.shape

    torch._check(A.dtype == torch.int8, lambda: "B must be int8")
    torch._check(B.dtype == torch.int8, lambda: "A must be int8")
    torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
    torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
    torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
    torch._check(out.dtype == torch.int32)

    shapeC = (*shapeB[:-1], shapeA[0])
    torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}")

    k, m = shapeA
    n = prod(shapeB[:-1])
    lda = shapeA[-1]  # Weights (outputs, inputs)
    ldb = shapeB[-1]  # Activations (batch, tokens, inputs)
    ldc = shapeC[-1]  # Output (batch, tokens, outputs)

    torch._check(
        lda == ldb,
        lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
    )

    # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
    # We'll fall back to a slower fp32 calculation in this circumstance.
    # Fortunately, this should not be very common.
    if lda % 4 != 0:
        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
        return out.copy_(result)

    with _cuda_device_of(A):
        ctx = CUBLAS_Context.get_instance().get_context(A.device)
        ptrA = get_ptr(A)
        ptrB = get_ptr(B)
        ptrC = get_ptr(out)
        ptrRowScale = None
        m = ct.c_int32(m)
        n = ct.c_int32(n)
        k = ct.c_int32(k)
        lda = ct.c_int32(lda)
        ldb = ct.c_int32(ldb)
        ldc = ct.c_int32(ldc)
        stream = _get_tensor_stream(A)

        has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)

    if has_error:
        if has_error == 100:  # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
            # TODO: Warn and implement a fallback to fp32 compute?
            raise NotImplementedError("int8_linear_matmul not implemented!")
        else:
            raise RuntimeError(
                f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}"
            )

    return out
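# For orientation, a pure-PyTorch sketch of what the cuBLASLt call above computes.
# Names and shapes here are illustrative, not part of the API: for int8 activations
# `act` of shape (..., k) and int8 weights `wt` of shape (m, k), the op accumulates
# in int32, matching the divisible-by-4 fallback path:
#
#   ref = torch.matmul(act.float(), wt.float().t()).to(torch.int32)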
@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # Note: cuda kernel only currently supports fp16 output.
    # We'll later cast to desired dtype if needed.
    out = torch.empty_like(A, dtype=torch.float16)

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    numRows = ct.c_int32(prod(A.shape[:-1]))
    numCols = ct.c_int32(A.shape[-1])

    # Note: fused bias in the kernel is only supported for fp16
    # TODO(matthewdouglas): Consider supporting bf16 fused bias
    ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None

    with _cuda_device_of(A):
        lib.cdequant_mm_int32_fp16(
            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
        )

    # Add bias separately if not fused in kernel
    if bias is not None and bias.dtype != torch.float16:
        out.add_(bias)

    return out.to(dtype or torch.float16)
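# For reference, the dequantization the fused fp16 kernel above performs is
# approximately this pure-PyTorch computation for a 2-D A (127 is the int8 scale
# on each side, hence the 127**2 divisor; `bias` is optional):
#
#   out = A.float() * row_stats.unsqueeze(-1) * col_stats.unsqueeze(0) / (127.0 * 127.0)
#   out = (out + bias).to(torch.float16)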
@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
def _(A: torch.Tensor, threshold=0.0):
    torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
    torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")

    rows = prod(A.shape[:-1])
    cols = A.shape[-1]

    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)

    outlier_cols = None

    if threshold > 0.0:
        # TODO we could improve perf of this
        outliers = A.abs() >= threshold

        if outliers.any():
            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
        else:
            # Needed for torch.compile support.
            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

    with _cuda_device_of(A):
        lib.cint8_vector_quant(
            get_ptr(A),
            get_ptr(out_row),
            get_ptr(row_stats),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            _get_tensor_stream(A),
        )

    # Zero out values from outlier columns across all rows.
    # The kernel will handle this for outliers themselves, so we can optimize for rows=1.
    if rows > 1 and outlier_cols is not None:
        out_row[:, outlier_cols] = 0

    return out_row, row_stats, outlier_cols


@register_kernel("bitsandbytes::int8_double_quant", "cuda")
def _(
    A: torch.Tensor,
    threshold=0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Use CUDA kernel for rowwise and COO tensor
    quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
        A,
        threshold=threshold,
    )

    # PyTorch impl for colwise
    col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
    if threshold > 0.0 and outlier_mask is not None:
        A = A.masked_fill(outlier_mask, 0.0)
    quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)

    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols


def _get_col_absmax(
    A: torch.Tensor,
    threshold=0.0,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    torch._check(A.is_floating_point())

    outlier_mask = None

    absA = A.abs().view(-1, A.shape[-1])

    if threshold > 0.0:
        # Filter outliers from stats when enabled
        outlier_mask = absA >= threshold
        absA.masked_fill_(outlier_mask, 0.0)

    # shape [cols]; unsqueeze(0) gives [1, cols]
    col_stats = absA.amax(dim=0, keepdim=False).float()

    return col_stats, outlier_mask


@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    A = A.contiguous()
    torch._check_is_size(blocksize)
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty_like(A, dtype=torch.uint8)

    with _cuda_device_of(A):
        args = (
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int32(blocksize),
            ct.c_int(A.numel()),
        )

        if A.dtype == torch.float16:
            lib.cquantize_blockwise_fp16(*args)
        elif A.dtype == torch.bfloat16:
            lib.cquantize_blockwise_bf16(*args)
        elif A.dtype == torch.float32:
            lib.cquantize_blockwise_fp32(*args)
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")

    return out, absmax
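# Note on the block count used above: `-(n // -blocksize)` is ceiling division,
# e.g. n=4100 with blocksize=4096 yields 2 blocks. Conceptually (an illustrative
# sketch, not the kernel itself), each block is normalized by its own absmax and
# each value maps to the index of the nearest entry in the 256-entry `code` table:
#
#   block = A.flatten()[i * blocksize : (i + 1) * blocksize]
#   absmax[i] = block.abs().max()
#   out[i * blocksize + j] = (code - block[j] / absmax[i]).abs().argmin()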
@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    out = torch.empty_like(A, dtype=dtype)
    _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
    return out


@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)


def _dequantize_blockwise_impl(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
        dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
    )

    with _cuda_device_of(A):
        args = (
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(blocksize),
            ct.c_int(A.numel()),
            _get_tensor_stream(A),
        )

        if dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(*args)
        elif dtype == torch.bfloat16:
            lib.cdequantize_blockwise_bf16(*args)
        elif dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(*args)


@register_kernel("bitsandbytes::quantize_4bit", "cuda")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)

    with _cuda_device_of(A):
        args = (
            None,
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int32(blocksize),
            ct.c_int32(n),
        )

        if A.dtype == torch.bfloat16:
            if quant_type == "fp4":
                lib.cquantize_blockwise_bf16_fp4(*args)
            else:
                lib.cquantize_blockwise_bf16_nf4(*args)
        elif A.dtype == torch.float16:
            if quant_type == "fp4":
                lib.cquantize_blockwise_fp16_fp4(*args)
            else:
                lib.cquantize_blockwise_fp16_nf4(*args)
        elif A.dtype == torch.float32:
            if quant_type == "fp4":
                lib.cquantize_blockwise_fp32_fp4(*args)
            else:
                lib.cquantize_blockwise_fp32_nf4(*args)

    return out, absmax
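# Note on the 4-bit output buffer above: two 4-bit codes are packed per byte, so
# uint8 storage holds ceil(n / 2) elements; wider `quant_storage` dtypes view the
# same bytes in larger units. Illustrative example: n = 4096 values ->
# (n + 1) // (1 * 2) = 2048 uint8 elements, or (n + 1) // (2 * 2) = 1024 elements
# when quant_storage=torch.float16 (itemsize 2).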
@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    out = torch.empty(shape, dtype=dtype, device=A.device)
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])
    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    with _cuda_device_of(A):
        args = (
            None,
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(blocksize),
            ct.c_int32(out.numel()),
            _get_tensor_stream(A),
        )

        if out.dtype == torch.bfloat16:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_bf16_fp4(*args)
            else:
                lib.cdequantize_blockwise_bf16_nf4(*args)
        elif out.dtype == torch.float16:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_fp16_fp4(*args)
            else:
                lib.cdequantize_blockwise_fp16_nf4(*args)
        elif out.dtype == torch.float32:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_fp32_fp4(*args)
            else:
                lib.cdequantize_blockwise_fp32_nf4(*args)


@register_kernel("bitsandbytes::gemv_4bit", "cuda")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    shape = (*A.shape[:-1], shapeB[0])
    out = torch.empty(shape, device=A.device, dtype=A.dtype)
    _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
    return out


@register_kernel("bitsandbytes::gemv_4bit.out", "cuda")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    torch._check(
        out.shape == (*A.shape[:-1], shapeB[0]),
        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
    )
    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
    _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)


def _gemv_4bit_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    torch._check_is_size(blocksize)

    # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
    # torch._check(
    #     A.numel() == A.size(-1),
    #     lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
    # )
    # torch._check(
    #     A.dtype in [torch.float16, torch.bfloat16, torch.float32],
    #     lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    # )
    # torch._check(
    #     B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
    #     lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    # )
    # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
    # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

    m = ct.c_int32(shapeB[0])
    n = ct.c_int32(1)
    k = ct.c_int32(shapeB[1])

    lda = m
    ldb = ct.c_int32((A.shape[-1] + 1) // 2)
    ldc = m

    stream = _get_tensor_stream(A)

    with _cuda_device_of(A):
        if A.dtype == torch.float16:
            lib.cgemm_4bit_inference_naive_fp16(
                m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(code), get_ptr(out),
                lda, ldb, ldc, ct.c_int32(blocksize), stream,
            )
        elif A.dtype == torch.bfloat16:
            lib.cgemm_4bit_inference_naive_bf16(
                m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(code), get_ptr(out),
                lda, ldb, ldc, ct.c_int32(blocksize), stream,
            )
        elif A.dtype == torch.float32:
            lib.cgemm_4bit_inference_naive_fp32(
                m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(code), get_ptr(out),
                lda, ldb, ldc, ct.c_int32(blocksize), stream,
            )
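# For reference, gemv_4bit is approximately equivalent to dequantizing B and
# taking a single-token matvec (an illustrative sketch assuming nf4 weights;
# the fused kernel dequantizes on the fly instead of materializing B_deq):
#
#   B_deq = torch.ops.bitsandbytes.dequantize_4bit(B, absmax, blocksize, "nf4", shapeB, A.dtype)
#   ref = A @ B_deq.t()  # shape (*A.shape[:-1], shapeB[0])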
"""C FUNCTIONS FOR OPTIMIZERS"""

str2optimizer32bit = {
    "adam": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum32bit_grad_32,
        lib.cmomentum32bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop32bit_grad_32,
        lib.crmsprop32bit_grad_16,
    ),
    "lion": (
        lib.clion32bit_grad_fp32,
        lib.clion32bit_grad_fp16,
        lib.clion32bit_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad32bit_grad_32,
        lib.cadagrad32bit_grad_16,
    ),
    "lamb": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix32bit_grad_fp32,
        lib.cademamix32bit_grad_fp16,
        lib.cademamix32bit_grad_bf16,
    ),
}

str2optimizer8bit_blockwise = {
    "adam": (
        lib.cadam_8bit_blockwise_grad_fp32,
        lib.cadam_8bit_blockwise_grad_fp16,
        lib.cadam_8bit_blockwise_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum_8bit_blockwise_grad_fp32,
        lib.cmomentum_8bit_blockwise_grad_fp16,
        lib.cmomentum_8bit_blockwise_grad_bf16,
    ),
    "rmsprop": (
        lib.crmsprop_8bit_blockwise_grad_fp32,
        lib.crmsprop_8bit_blockwise_grad_fp16,
        lib.crmsprop_8bit_blockwise_grad_bf16,
    ),
    "lion": (
        lib.clion_8bit_blockwise_grad_fp32,
        lib.clion_8bit_blockwise_grad_fp16,
        lib.clion_8bit_blockwise_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad_8bit_blockwise_grad_fp32,
        lib.cadagrad_8bit_blockwise_grad_fp16,
        lib.cadagrad_8bit_blockwise_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix_8bit_blockwise_grad_fp32,
        lib.cademamix_8bit_blockwise_grad_fp16,
        lib.cademamix_8bit_blockwise_grad_bf16,
    ),
}


def _optimizer_update_32bit_impl(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    optim_fns = str2optimizer32bit.get(optimizer_name, None)
    if optim_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
        )

    if g.dtype == torch.float32:
        optim_func = optim_fns[0]
    elif g.dtype == torch.float16:
        optim_func = optim_fns[1]
    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
        optim_func = optim_fns[2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )

    with _cuda_device_of(g):
        optim_func(
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )
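# For reference, the fp32 Adam path dispatched above fuses roughly the standard
# bias-corrected update below (a sketch only; it omits the max_unorm,
# gnorm_scale, and skip_zeros handling the kernel performs):
#
#   state1 = beta1 * state1 + (1 - beta1) * g           # first moment
#   state2 = beta2 * state2 + (1 - beta2) * g * g       # second moment
#   m_hat = state1 / (1 - beta1**step)
#   v_hat = state2 / (1 - beta2**step)
#   p -= lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * p)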
def _optimizer_update_8bit_blockwise_impl(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    # torch._check(
    #     g.numel() == p.numel(),
    #     lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    # )
    # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    # torch._check(
    #     g.dtype in compute_dtypes,
    #     lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    # )
    # torch._check(
    #     g.dtype == p.dtype,
    #     lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    # )
    # torch._check(
    #     state1.dtype == torch.uint8,
    #     lambda: f"state1 must be uint8, got {state1.dtype}",
    # )
    # torch._check(
    #     qmap1.dtype == absmax1.dtype == torch.float32,
    #     lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
    # )
    # if state2 is not None:
    #     torch._check(
    #         state2.dtype == torch.uint8,
    #         lambda: f"state2 must be uint8, got {state2.dtype}",
    #     )
    #     torch._check(
    #         qmap2.dtype == absmax2.dtype == torch.float32,
    #         lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
    #     )

    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
    if optimizer_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
        )

    if g.dtype == torch.float32:
        optimizer_fn = optimizer_fns[0]
    elif g.dtype == torch.float16:
        optimizer_fn = optimizer_fns[1]
    elif g.dtype == torch.bfloat16:
        optimizer_fn = optimizer_fns[2]
    else:
        raise ValueError(
            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
        )

    with _cuda_device_of(g):
        optimizer_fn(
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(absmax1),
            get_ptr(absmax2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )


register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
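# A minimal smoke-test sketch, not part of the library surface. It assumes a
# CUDA build of bitsandbytes and an available CUDA device, and must be run as a
# module (python -m ...) so the relative imports resolve; shapes are illustrative.
if __name__ == "__main__":
    if torch.cuda.is_available():
        x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
        w = torch.randint(-128, 128, (128, 64), dtype=torch.int8, device="cuda")
        xq, row_stats, _ = torch.ops.bitsandbytes.int8_vectorwise_quant(x)
        acc = torch.ops.bitsandbytes.int8_linear_matmul(xq, w)
        print(acc.shape, acc.dtype)  # expected: torch.Size([4, 128]) torch.int32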