from collections.abc import Sequence
import ctypes as ct
from math import prod
from typing import Optional

import torch

from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ...cextension import ROCM_WARP_SIZE_64, lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
def _(A: torch.Tensor, B: torch.Tensor):
    out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
    return _int8_linear_matmul_impl(A, B, out)


@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    _int8_linear_matmul_impl(A, B, out)


def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    # cuBLASLt computes B^T @ A; swap the operands here so the checks and error
    # messages below still refer to the caller's A and B.
    A, B = B, A

    shapeA = A.shape
    shapeB = B.shape

    torch._check(A.dtype == torch.int8, lambda: "B must be int8")
    torch._check(B.dtype == torch.int8, lambda: "A must be int8")
    torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
    torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
    torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
    torch._check(out.dtype == torch.int32)

    shapeC = (*shapeB[:-1], shapeA[0])
    torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}")

    k, m = shapeA
    n = prod(shapeB[:-1])
    lda = shapeA[-1]
    ldb = shapeB[-1]
    ldc = shapeC[-1]

    torch._check(
        lda == ldb,
        lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
    )

    # cuBLASLt int8 matmul requires the inner dimension to be a multiple of 4.
    # Fall back to a slower fp32 matmul when that alignment does not hold; in
    # practice this path should be rare.
    if lda % 4 != 0:
        result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
        return out.copy_(result)

    with _cuda_device_of(A):
        ctx = CUBLAS_Context.get_instance().get_context(A.device)
        ptrA = get_ptr(A)
        ptrB = get_ptr(B)
        ptrC = get_ptr(out)
        ptrRowScale = None
        m = ct.c_int32(m)
        n = ct.c_int32(n)
        k = ct.c_int32(k)
        lda = ct.c_int32(lda)
        ldb = ct.c_int32(ldb)
        ldc = ct.c_int32(ldc)
        stream = _get_tensor_stream(A)

        has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)

    if has_error:
        if has_error == 100:
            # Error code 100 is the native library's "not implemented" sentinel.
            raise NotImplementedError("int8_linear_matmul not implemented!")
        else:
            raise RuntimeError(
                f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}"
            )

    return out


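# A hedged usage sketch (illustrative only, not part of the kernel registration):
# calling the op through the torch.ops dispatcher. A is (..., k) int8, B is (n, k)
# int8, and the result is (..., n) int32, matching torch.matmul(A, B.t()).
#
#   A = torch.randint(-128, 128, (16, 64), dtype=torch.int8, device="cuda")
#   B = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")
#   C = torch.ops.bitsandbytes.int8_linear_matmul(A, B)  # (16, 32), torch.int32
#   # For these magnitudes the fp32 reference is exact:
#   assert torch.equal(C, torch.matmul(A.float(), B.float().t()).to(torch.int32))

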
@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # The CUDA kernel only supports fp16 output; we cast to the requested dtype
    # at the end if needed.
    out = torch.empty_like(A, dtype=torch.float16)

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    numRows = ct.c_int32(prod(A.shape[:-1]))
    numCols = ct.c_int32(A.shape[-1])

    # The kernel can only fuse an fp16 bias; other bias dtypes are added below.
    ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None

    with _cuda_device_of(A):
        lib.cdequant_mm_int32_fp16(
            ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
        )

    # Add the bias separately if the kernel could not fuse it.
    if bias is not None and bias.dtype != torch.float16:
        out.add_(bias)

    return out.to(dtype or torch.float16)


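# Hedged sketch of the dequantization step that typically follows
# int8_linear_matmul: the int32 accumulator is rescaled with the per-row and
# per-column absmax statistics gathered at quantization time (names below are
# illustrative).
#
#   C_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A_i8, B_i8)
#   y = torch.ops.bitsandbytes.int8_mm_dequant(C_i32, row_stats, col_stats, dtype=torch.bfloat16)

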
@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
def _(A: torch.Tensor, threshold=0.0):
    torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
    torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")

    rows = prod(A.shape[:-1])
    cols = A.shape[-1]

    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)

    outlier_cols = None

    if threshold > 0.0:
        # Find columns containing at least one value at or above the threshold.
        outliers = A.abs() >= threshold

        if outliers.any():
            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
        else:
            # Return an empty index tensor rather than None so the op keeps a
            # consistent return type (e.g. for torch.compile).
            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

    with _cuda_device_of(A):
        lib.cint8_vector_quant(
            get_ptr(A),
            get_ptr(out_row),
            get_ptr(row_stats),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            _get_tensor_stream(A),
        )

    # Zero out values from outlier columns across all rows; the kernel already
    # handles this in the single-row case.
    if rows > 1 and outlier_cols is not None:
        out_row[:, outlier_cols] = 0

    return out_row, row_stats, outlier_cols


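# Hedged sketch: row-wise int8 quantization with outlier decomposition in the
# LLM.int8() style. Columns containing any |value| >= threshold are reported so
# the caller can handle them separately in higher precision.
#
#   A = torch.randn(8, 128, dtype=torch.float16, device="cuda")
#   q, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=6.0)
#   # q: (8, 128) int8; row_stats: (8,) fp32 per-row absmax (excluding outliers);
#   # outlier_cols: int64 indices of the columns zeroed out in q.

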
@register_kernel("bitsandbytes::int8_double_quant", "cuda")
def _(
    A: torch.Tensor,
    threshold=0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Row-wise quantization uses the CUDA kernel.
    quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
        A,
        threshold=threshold,
    )

    # Column-wise quantization is implemented in PyTorch.
    col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
    if threshold > 0.0 and outlier_mask is not None:
        A = A.masked_fill(outlier_mask, 0.0)
    quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)

    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols


def _get_col_absmax(
    A: torch.Tensor,
    threshold=0.0,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    torch._check(A.is_floating_point())

    outlier_mask = None

    absA = A.abs().view(-1, A.shape[-1])

    if threshold > 0.0:
        # Zero out the outliers so they do not contribute to the column stats.
        outlier_mask = absA >= threshold
        absA.masked_fill_(outlier_mask, 0.0)

    # Per-column absmax over all rows.
    col_stats = absA.amax(dim=0, keepdim=False).float()

    return col_stats, outlier_mask


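# Hedged sketch: int8_double_quant produces both row-wise and column-wise int8
# views of A in one call, using symmetric scaling q = round(127 * a / absmax):
#
#   qr, qc, r_stats, c_stats, outlier_cols = torch.ops.bitsandbytes.int8_double_quant(A, threshold=6.0)
#   # qr is quantized per row (scales r_stats), qc per column (scales c_stats).

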
@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    A = A.contiguous()
    torch._check_is_size(blocksize)

    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)  # ceil(n / blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty_like(A, dtype=torch.uint8)

    with _cuda_device_of(A):
        args = (
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int32(blocksize),
            ct.c_int(A.numel()),
        )

        if A.dtype == torch.float16:
            lib.cquantize_blockwise_fp16(*args)
        elif A.dtype == torch.bfloat16:
            lib.cquantize_blockwise_bf16(*args)
        elif A.dtype == torch.float32:
            lib.cquantize_blockwise_fp32(*args)
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")

    return out, absmax


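# Hedged sketch of blockwise 8-bit quantization. `code` is a 256-entry fp32
# codebook in [-1, 1]; bitsandbytes.functional.create_dynamic_map() is one
# assumed way to build it.
#
#   from bitsandbytes.functional import create_dynamic_map
#   code = create_dynamic_map().to("cuda")
#   A = torch.randn(4096, dtype=torch.float16, device="cuda")
#   packed, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 256)
#   # packed: one uint8 code index per element; absmax: one fp32 scale per 256-element block.

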
@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    out = torch.empty_like(A, dtype=dtype)
    _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
    return out


@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)


def _dequantize_blockwise_impl(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
        dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
    )

    with _cuda_device_of(A):
        args = (
            get_ptr(code),
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(blocksize),
            ct.c_int(A.numel()),
            _get_tensor_stream(A),
        )

        if dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(*args)
        elif dtype == torch.bfloat16:
            lib.cdequantize_blockwise_bf16(*args)
        elif dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(*args)


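# Hedged round-trip sketch, continuing the quantize_blockwise example above:
#
#   restored = torch.ops.bitsandbytes.dequantize_blockwise(packed, absmax, code, 256, torch.float16)
#   # restored approximates A up to the 8-bit codebook resolution within each block.

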
@register_kernel("bitsandbytes::quantize_4bit", "cuda")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()
    blocks = -(n // -blocksize)  # ceil(n / blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)

    # Two 4-bit values are packed per byte, so the output holds n/2 bytes,
    # viewed as elements of quant_storage.
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)

    with _cuda_device_of(A):
        args = (
            None,
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int32(blocksize),
            ct.c_int32(n),
        )

        if A.dtype == torch.bfloat16:
            if quant_type == "fp4":
                lib.cquantize_blockwise_bf16_fp4(*args)
            else:
                lib.cquantize_blockwise_bf16_nf4(*args)
        elif A.dtype == torch.float16:
            if quant_type == "fp4":
                lib.cquantize_blockwise_fp16_fp4(*args)
            else:
                lib.cquantize_blockwise_fp16_nf4(*args)
        elif A.dtype == torch.float32:
            if quant_type == "fp4":
                lib.cquantize_blockwise_fp32_fp4(*args)
            else:
                lib.cquantize_blockwise_fp32_nf4(*args)

    return out, absmax


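# Hedged sketch: NF4 quantization of a weight matrix.
#
#   W = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
#   packed, absmax = torch.ops.bitsandbytes.quantize_4bit(W, 64, "nf4", torch.uint8)
#   # packed: (W.numel() // 2, 1) uint8 with two 4-bit codes per byte;
#   # absmax: (W.numel() // 64,) fp32, one scale per 64-element block.

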
@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    out = torch.empty(shape, dtype=dtype, device=A.device)
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    A = A.contiguous()
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    with _cuda_device_of(A):
        args = (
            None,
            get_ptr(A),
            get_ptr(absmax),
            get_ptr(out),
            ct.c_int(blocksize),
            ct.c_int32(out.numel()),
            _get_tensor_stream(A),
        )

        if out.dtype == torch.bfloat16:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_bf16_fp4(*args)
            else:
                lib.cdequantize_blockwise_bf16_nf4(*args)
        elif out.dtype == torch.float16:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_fp16_fp4(*args)
            else:
                lib.cdequantize_blockwise_fp16_nf4(*args)
        elif out.dtype == torch.float32:
            if quant_type == "fp4":
                lib.cdequantize_blockwise_fp32_fp4(*args)
            else:
                lib.cdequantize_blockwise_fp32_nf4(*args)


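# Hedged sketch: restoring the NF4-packed weight from the quantize_4bit example.
#
#   W_restored = torch.ops.bitsandbytes.dequantize_4bit(
#       packed, absmax, 64, "nf4", (128, 64), torch.bfloat16
#   )

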
@register_kernel("bitsandbytes::gemv_4bit", "cuda")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    shape = (*A.shape[:-1], shapeB[0])
    out = torch.empty(shape, device=A.device, dtype=A.dtype)
    _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
    return out


@register_kernel("bitsandbytes::gemv_4bit.out", "cuda")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    torch._check(
        out.shape == (*A.shape[:-1], shapeB[0]),
        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
    )
    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
    _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)


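# Hedged usage sketch for the gemv op registered above (the implementation
# follows below). It reuses `packed`/`absmax` from the quantize_4bit sketch;
# `code` is the 16-entry fp32 codebook for the quant type
# (bitsandbytes.functional.get_4bit_type is one assumed way to obtain it), and
# the packed layout expected for B follows the library's internal convention.
#
#   x = torch.randn(1, 64, dtype=torch.bfloat16, device="cuda")
#   y = torch.ops.bitsandbytes.gemv_4bit(x, packed, (128, 64), absmax, code, 64)
#   # y: (1, 128) bf16; B stays packed and is dequantized on the fly.

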
def _gemv_4bit_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    torch._check_is_size(blocksize)

    # Input validation for the single-token inference path.
    torch._check(
        A.numel() == A.size(-1),
        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
    )
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )
    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

    m = ct.c_int32(shapeB[0])
    n = ct.c_int32(1)
    k = ct.c_int32(shapeB[1])

    lda = m
    ldb = ct.c_int32((A.shape[-1] + 1) // 2)  # B is packed: two 4-bit values per byte
    ldc = m

    stream = _get_tensor_stream(A)

    with _cuda_device_of(A):
        if A.dtype == torch.float16:
            lib.cgemm_4bit_inference_naive_fp16(
                m,
                n,
                k,
                get_ptr(A),
                get_ptr(B),
                get_ptr(absmax),
                get_ptr(code),
                get_ptr(out),
                lda,
                ldb,
                ldc,
                ct.c_int32(blocksize),
                stream,
            )
        elif A.dtype == torch.bfloat16:
            lib.cgemm_4bit_inference_naive_bf16(
                m,
                n,
                k,
                get_ptr(A),
                get_ptr(B),
                get_ptr(absmax),
                get_ptr(code),
                get_ptr(out),
                lda,
                ldb,
                ldc,
                ct.c_int32(blocksize),
                stream,
            )
        elif A.dtype == torch.float32:
            lib.cgemm_4bit_inference_naive_fp32(
                m,
                n,
                k,
                get_ptr(A),
                get_ptr(B),
                get_ptr(absmax),
                get_ptr(code),
                get_ptr(out),
                lda,
                ldb,
                ldc,
                ct.c_int32(blocksize),
                stream,
            )


"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
    "adam": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum32bit_grad_32,
        lib.cmomentum32bit_grad_16,
    ),
    "rmsprop": (
        lib.crmsprop32bit_grad_32,
        lib.crmsprop32bit_grad_16,
    ),
    "lion": (
        lib.clion32bit_grad_fp32,
        lib.clion32bit_grad_fp16,
        lib.clion32bit_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad32bit_grad_32,
        lib.cadagrad32bit_grad_16,
    ),
    # LAMB reuses the Adam update kernels.
    "lamb": (
        lib.cadam32bit_grad_fp32,
        lib.cadam32bit_grad_fp16,
        lib.cadam32bit_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix32bit_grad_fp32,
        lib.cademamix32bit_grad_fp16,
        lib.cademamix32bit_grad_bf16,
    ),
}

str2optimizer8bit_blockwise = {
    "adam": (
        lib.cadam_8bit_blockwise_grad_fp32,
        lib.cadam_8bit_blockwise_grad_fp16,
        lib.cadam_8bit_blockwise_grad_bf16,
    ),
    "momentum": (
        lib.cmomentum_8bit_blockwise_grad_fp32,
        lib.cmomentum_8bit_blockwise_grad_fp16,
        lib.cmomentum_8bit_blockwise_grad_bf16,
    ),
    "rmsprop": (
        lib.crmsprop_8bit_blockwise_grad_fp32,
        lib.crmsprop_8bit_blockwise_grad_fp16,
        lib.crmsprop_8bit_blockwise_grad_bf16,
    ),
    "lion": (
        lib.clion_8bit_blockwise_grad_fp32,
        lib.clion_8bit_blockwise_grad_fp16,
        lib.clion_8bit_blockwise_grad_bf16,
    ),
    "adagrad": (
        lib.cadagrad_8bit_blockwise_grad_fp32,
        lib.cadagrad_8bit_blockwise_grad_fp16,
        lib.cadagrad_8bit_blockwise_grad_bf16,
    ),
    "ademamix": (
        lib.cademamix_8bit_blockwise_grad_fp32,
        lib.cademamix_8bit_blockwise_grad_fp16,
        lib.cademamix_8bit_blockwise_grad_bf16,
    ),
}


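# Hedged sketch: each table maps an optimizer name to its native update kernels,
# ordered (fp32, fp16[, bf16]) by gradient dtype. Dispatch amounts to:
#
#   fns = str2optimizer8bit_blockwise["adam"]
#   fn = fns[{torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}[g.dtype]]

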
def _optimizer_update_32bit_impl(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    optim_fns = str2optimizer32bit.get(optimizer_name, None)
    if optim_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer32bit.keys())}"
        )
    if g.dtype == torch.float32:
        optim_func = optim_fns[0]
    elif g.dtype == torch.float16:
        optim_func = optim_fns[1]
    elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
        # Some optimizers (e.g. momentum, rmsprop, adagrad) have no bf16 kernel.
        optim_func = optim_fns[2]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
        )

    with _cuda_device_of(g):
        optim_func(
            get_ptr(g),
            get_ptr(p),
            get_ptr(state1),
            get_ptr(state2),
            get_ptr(unorm_vec),
            ct.c_float(max_unorm),
            ct.c_float(param_norm),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_float(weight_decay),
            ct.c_int32(step),
            ct.c_float(lr),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )


def _optimizer_update_8bit_blockwise_impl(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    # Input validation: gradient and parameter must match in size and use a
    # supported compute dtype; the quantized state is uint8 with fp32 metadata.
    torch._check(
        g.numel() == p.numel(),
        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    )
    torch._check(
        g.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    )
    torch._check(
        g.dtype == p.dtype,
        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    )
    torch._check(state1.dtype == torch.uint8, lambda: f"state1 must be uint8, got {state1.dtype}")
    torch._check(
        qmap1.dtype == absmax1.dtype == torch.float32,
        lambda: f"Expected qmap1 and absmax1 to be float32, got {qmap1.dtype} and {absmax1.dtype}",
    )
    if state2 is not None:
        torch._check(state2.dtype == torch.uint8, lambda: f"state2 must be uint8, got {state2.dtype}")
        torch._check(
            qmap2.dtype == absmax2.dtype == torch.float32,
            lambda: f"Expected qmap2 and absmax2 to be float32, got {qmap2.dtype} and {absmax2.dtype}",
        )

    optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
    if optimizer_fns is None:
        raise ValueError(
            f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
        )

    if g.dtype == torch.float32:
        optimizer_fn = optimizer_fns[0]
    elif g.dtype == torch.float16:
        optimizer_fn = optimizer_fns[1]
    elif g.dtype == torch.bfloat16:
        optimizer_fn = optimizer_fns[2]
    else:
        raise ValueError(
            f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
        )

    with _cuda_device_of(g):
        optimizer_fn(
            get_ptr(p),
            get_ptr(g),
            get_ptr(state1),
            get_ptr(state2),
            ct.c_float(beta1),
            ct.c_float(beta2),
            ct.c_float(beta3),
            ct.c_float(alpha),
            ct.c_float(eps),
            ct.c_int32(step),
            ct.c_float(lr),
            get_ptr(qmap1),
            get_ptr(qmap2),
            get_ptr(absmax1),
            get_ptr(absmax2),
            ct.c_float(weight_decay),
            ct.c_float(gnorm_scale),
            ct.c_bool(skip_zeros),
            ct.c_int32(g.numel()),
        )


register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
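
# Hedged sketch of a single 32-bit Adam step through the registered op. The
# argument layout mirrors _optimizer_update_32bit_impl above; g, p, state1,
# state2, step, and lr are assumed to be managed by the caller (normally the
# bitsandbytes optimizer classes).
#
#   torch.ops.bitsandbytes.optimizer_update_32bit(
#       "adam", g, p, state1, state2, None,
#       0.0, 0.0,              # max_unorm, param_norm (unused without unorm_vec)
#       0.9, 0.999, 0.0,       # beta1, beta2, beta3
#       0.0, 1e-8, 1e-2,       # alpha, eps, weight_decay
#       step, lr, 1.0, False,  # step, lr, gnorm_scale, skip_zeros
#   )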