| from collections.abc import Sequence |
| from functools import wraps |
| from math import prod, sqrt |
| from typing import Optional |
|
|
| import torch |
|
|
| from ..._ops import register_kernel |
| from ..utils import CODE |
|
|
|
|
| def _try_torch_compile(func=None, **compile_kwargs): |
| """ |
| Wrapper around torch.compile that falls back to the original function if compilation fails. |
| """ |
|
|
| def decorator(fn): |
| try: |
| compiled_fn = torch.compile(fn, **compile_kwargs) |
|
|
| @wraps(fn) |
| def wrapper(*args, **kwargs): |
| try: |
| return compiled_fn(*args, **kwargs) |
| except Exception: |
| return fn(*args, **kwargs) |
|
|
| return wrapper |
| except Exception: |
| return fn |
|
|
| if func is None: |
| return decorator |
| else: |
| return decorator(func) |
|
|
|
|
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize an int32 int8-matmul result using per-row and per-column stats.

    Each element ``A[i, j]`` is rescaled by ``row_stats[i] * col_stats[j] / (127 * 127)``.
    This mirrors the ``127 / row_stats`` scaling used by ``int8_vectorwise_quant`` in this
    module. ``bias`` (if given) is added after rescaling. The result is cast to ``dtype``,
    or float16 when ``dtype`` is None.

    Note: ``A`` is viewed as 2D ``(-1, A.shape[-1])`` and the result keeps that 2D shape.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    A_calc = A.view(-1, A.shape[-1])
    row_stats = row_stats.reshape(-1).unsqueeze(-1)
    col_stats = col_stats.reshape(-1).unsqueeze(0)

    # Exact 1 / (127 * 127) (~6.2000124e-05); both int8 operands were scaled by 127 / absmax.
    out = A_calc * (row_stats * col_stats) * (1.0 / (127.0 * 127.0))
    if bias is not None:
        out += bias

    return out.to(dtype or torch.float16)
|
|
|
|
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Int8 scaled matmul with the outlier columns of ``A`` handled in ``A.dtype``.

    The int8 product ``CA x CB`` (dequantized with ``SCA``/``SCB``, plus ``bias``) is
    computed for all columns. If ``outlier_cols`` is non-empty, the matching columns of ``A``
    are multiplied against the dequantized matching columns of ``CB``, and that product is
    added to the result.

    Returns:
        ``(output, subA)``. ``subA`` holds the outlier columns of ``A``, or an empty tensor
        when there are no outliers.
    """
    if outlier_cols is None or not outlier_cols.numel():
        # No outliers: return an empty tensor (not None) as the second output.
        subA = torch.empty(0, device=A.device, dtype=A.dtype)
        subB = None
    else:
        # Original-precision inputs for the outlier columns.
        subA = A[:, outlier_cols].contiguous()
        # Dequantize the matching weight columns and transpose for addmm.
        outlier_weights = torch.ops.bitsandbytes.int8_vectorwise_dequant.default(
            CB[:, outlier_cols].contiguous(), SCB
        )
        subB = outlier_weights.to(A.dtype).t()

    # Int8 matmul + dequantization + bias.
    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

    if subB is not None:
        output = output.addmm(subA, subB)

    return output, subA
|
|
|
|
@register_kernel("bitsandbytes::int8_scaled_mm", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Int8 matmul of ``A`` and ``B`` followed by dequantization.

    Composes ``int8_linear_matmul`` (int32 result) with ``int8_mm_dequant``. ``bias`` is added
    during dequantization, and the output dtype defaults to float16.
    """
    int_product = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
    out_dtype = dtype or torch.float16
    return torch.ops.bitsandbytes.int8_mm_dequant.default(
        int_product, row_stats, col_stats, dtype=out_dtype, bias=bias
    )
|
|
|
|
@register_kernel("bitsandbytes::int8_linear_matmul", "default")
def _(A: torch.Tensor, B: torch.Tensor):
    """Return ``A @ B.t()`` as a new int32 tensor (delegates to ``_int8_linear_matmul_impl``)."""
    return _int8_linear_matmul_impl(A, B, out=None)
|
|
|
|
@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    """Out-variant of ``int8_linear_matmul``: write ``A @ B.t()`` into the int32 tensor ``out``."""
    # Give a descriptive error, consistent with the other dtype checks in this module.
    torch._check(out.dtype == torch.int32, lambda: f"out must be int32, got {out.dtype}")
    _int8_linear_matmul_impl(A, B, out)
|
|
|
|
| def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None): |
| |
| result = torch.matmul(A.float(), B.float().t()).to(torch.int32) |
| if out is not None: |
| result = out.copy_(result) |
| return result |
|
|
|
|
@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
def _(A: torch.Tensor, threshold=0.0):
    """Row-wise (vector-wise) absmax quantization of ``A`` to int8.

    Args:
        A: Input tensor. The reductions below use dim=0 and dim=1, so this assumes a 2D
            ``(rows, cols)`` layout (TODO confirm for higher-rank callers).
        threshold: If > 0, elements with ``abs(x) >= threshold`` are treated as outliers.
            They are zeroed before computing the row absmax. When there is more than one
            row, every column containing an outlier is zeroed in the quantized output.

    Returns:
        ``(out_row, row_stats, outlier_cols)``:
          * ``out_row``: int8 values ``round(A * 127 / row_absmax)``. Rows whose absmax is 0
            quantize to 0.
          * ``row_stats``: float32 per-row absmax.
          * ``outlier_cols``: int64 indices of columns with outliers. ``None`` when
            ``threshold <= 0``; an empty tensor when no outliers were found.

    ``A`` is modified in place while running (outliers zeroed) and restored before return.
    """
    rows = prod(A.shape[:-1])
    outlier_cols = None
    outlier_restore = None

    if threshold > 0.0:
        outliers = A.abs() >= threshold

        if outliers.any():
            # Find the columns that contain outliers and zero the outliers ahead of
            # quantization, keeping a backup so A can be restored afterwards.
            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
            outlier_restore = A[outliers].clone()
            A[outliers] = 0
        else:
            # Thresholding was requested but found nothing: return an empty index tensor.
            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

    # Absmax of each row.
    row_stats = torch.max(A.abs(), dim=1).values.float()

    # A row with absmax 0 would compute 0 * (127 / 0) = NaN, and casting NaN to int8 is
    # undefined. Use a scale of 0 for those rows so they quantize to exactly 0.
    scale = torch.where(row_stats > 0, 127.0 / row_stats, torch.zeros_like(row_stats))
    out_row = torch.round(A * scale.unsqueeze(-1)).to(torch.int8)

    # Zero out whole outlier columns across all rows (skipped for a single row).
    if rows > 1 and outlier_cols is not None:
        out_row[:, outlier_cols] = 0

    # Restore the caller's tensor.
    if outlier_restore is not None:
        A[outliers] = outlier_restore

    return out_row, row_stats, outlier_cols
|
|
|
|
@register_kernel("bitsandbytes::quantize_blockwise", "default")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Blockwise quantization of ``A`` against the lookup table ``code``.

    ``A`` is flattened and split into ``blocksize``-element blocks; the last block may be
    partial. Each block is multiplied by ``1 / absmax``, clamped to [-1, 1], and every value
    is replaced by the index of the nearest ``code`` entry.

    Returns:
        ``(out, absmax)``: uint8 indices shaped like ``A`` and the float32 per-block absmax.
    """
    torch._check_is_size(blocksize)

    numel = A.numel()
    full_blocks = numel // blocksize
    tail = numel % blocksize
    absmax = torch.zeros((full_blocks + (1 if tail else 0),), device=A.device, dtype=torch.float32)

    flat = A.reshape(numel)
    head = flat[: numel - tail].reshape(full_blocks, blocksize)
    absmax[:full_blocks] = torch.abs(head).max(dim=-1)[0]
    normalized = torch.clamp(head * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

    if tail:
        # The trailing partial block gets its own absmax in the last slot.
        tail_vals = flat[numel - tail :]
        absmax[-1] = torch.abs(tail_vals).max()
        tail_normalized = torch.clamp(tail_vals * (1 / absmax[-1]), -1, 1)
        normalized = torch.cat([normalized, tail_normalized], dim=0)

    # Nearest-code search: distance from every value to every code entry, then argmin.
    distances = torch.abs(normalized.unsqueeze(-1) - code.to(normalized.device))
    indices = torch.argmin(distances, dim=-1).to(torch.uint8)

    return indices.reshape(A.shape), absmax
|
|
|
|
@register_kernel("bitsandbytes::dequantize_blockwise", "default")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Map the uint8 indices in ``A`` through ``code`` and rescale each block by ``absmax``.

    Blocks hold ``blocksize`` consecutive flattened elements; the last block may be partial.
    The result is cast to ``dtype`` and reshaped to ``A.shape``.
    """
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    values = code[A.reshape(-1).int()]
    total = values.shape[-1]

    # Zero-pad the tail so the flat values split evenly into blocks of `blocksize`.
    tail = total % blocksize
    if tail != 0:
        values = torch.nn.functional.pad(values, (0, blocksize - tail), mode="constant", value=0)

    scaled = (values.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)

    # Drop the padding and restore the input's shape.
    return scaled[:total].reshape(A.shape)
|
|
|
|
@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Blockwise 4-bit (NF4/FP4) quantization of ``A``.

    ``A`` is flattened and split into blocks of ``blocksize`` elements; the last block may be
    partial. Each block is scaled by its absmax into [-1, 1], and every value is mapped to
    the index of the nearest entry of ``CODE[quant_type]``. Two 4-bit indices are packed per
    byte, with the first of each pair in the high nibble. An odd element count is padded
    with one zero nibble.

    Returns:
        ``(packed, absmax)``:
          * ``packed``: uint8 of shape ``(ceil(numel / 2), 1)``, re-viewed as
            ``quant_storage`` when that is not uint8.
          * ``absmax``: the float32 per-block scale.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()
    full_blocks = n // blocksize
    rem = n % blocksize
    blocks = full_blocks + 1 if rem else full_blocks
    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
    A_flattened = A.reshape(n)

    # Scale full blocks of the tensor to [-1, 1].
    A_full_blocks = A_flattened[: n - rem].reshape(full_blocks, blocksize)
    absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0]
    scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

    # Scale the trailing partial block, if any.
    if rem:
        A_rem = A_flattened[-rem:]
        absmax[-1] = torch.abs(A_rem).max()
        scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1)
        scaled = torch.cat([scaled, scaled_rem], dim=0)

    # Quantize: index of the nearest lookup-table entry, shape (n, 1).
    code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)

    # Packing needs an even count: with an odd n the even/odd slices below would differ in
    # length and the OR would fail, so append a single zero nibble.
    if n % 2:
        quantized = torch.cat([quantized, quantized.new_zeros((1, 1))], dim=0)

    # Pack two quantized values per byte (first value in the high nibble).
    packed = quantized[::2] << 4 | quantized[1::2]

    if quant_storage != torch.uint8:
        packed = packed.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()
|
|
|
|
def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference implementation of blockwise 4-bit (NF4/FP4) dequantization.

    Args:
        A: Packed codes, two 4-bit indices per byte (high nibble first). Non-uint8 storage
            is reinterpreted as raw uint8 bytes.
        absmax: Per-block scales, one per ``blocksize`` elements (the last block may be partial).
        blocksize: Number of elements per quantization block.
        quant_type: Key into ``CODE`` selecting the lookup table.
        shape: Shape of the original tensor; exactly ``prod(shape)`` values are produced.
        dtype: Output dtype.

    Returns:
        A ``dtype`` tensor reshaped to ``(-1, *shape[1:])``.
    """
    # Accept non-uint8 storage dtypes by reinterpreting the bytes.
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)
    A = A.reshape(-1)

    # Unpack two 4-bit indices per byte: high nibble first, then low nibble.
    idx = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
    idx[1::2] = A & 0xF
    idx[::2] = A >> 4

    # The code table is cast to the output dtype before the lookup.
    code = CODE[quant_type].to(dtype).to(A.device)
    out_dq = code[idx]

    # The element count comes from the requested shape, not the packed byte count: an
    # odd-sized tensor is packed with one trailing padding nibble that must be dropped.
    n = prod(shape)
    torch._check(
        out_dq.numel() >= n,
        lambda: f"packed 4-bit data holds {out_dq.numel()} values but shape {tuple(shape)} needs {n}",
    )
    if out_dq.numel() != n:
        out_dq = torch.narrow(out_dq, 0, 0, n)

    blocks = n // blocksize
    blocks += 1 if n % blocksize > 0 else 0
    rem = n % blocksize
    has_rem = rem > 0

    # Apply the per-block scales.
    if has_rem:
        out = torch.empty(n, dtype=dtype, device=A.device)
        out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
        out[n - rem :] = out_dq[n - rem :] * absmax[-1]
    else:
        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

    return out.reshape(-1, *shape[1:]).to(dtype)
|
|
|
|
@register_kernel("bitsandbytes::dequantize_4bit", "default")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Validate arguments, then dequantize blockwise 4-bit data via ``_dequantize_4bit_impl``.

    ``quant_type`` must be "nf4" or "fp4"; ``dtype`` must be bfloat16, float16 or float32.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")

    supported_dtypes = (torch.bfloat16, torch.float16, torch.float32)
    torch._check(
        dtype in supported_dtypes,
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    return _dequantize_4bit_impl(
        A,
        absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        shape=shape,
        dtype=dtype,
    )
|
|
|
|
@register_kernel("bitsandbytes::gemv_4bit", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """4-bit GEMV fallback: dequantize ``B`` to ``A.dtype``, then apply ``linear(A, B_dq)``.

    The quant type is inferred from ``code``: a positive second entry selects "fp4",
    otherwise "nf4".
    """
    is_fp4 = code[1] > 0
    quant_type = "fp4" if is_fp4 else "nf4"

    weight = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)
    return torch.nn.functional.linear(A, weight, bias=None)
|
|
|
|
# Integer identifiers for the optimizers handled by the 32-bit optimizer kernels below.
MOMENTUM, RMSPROP, ADAGRAD, ADAM, LION, ADEMAMIX = range(6)

# Maps the optimizer name passed to ``optimizer_update_32bit`` to its integer id.
name2optimizer_id = dict(
    momentum=MOMENTUM,
    rmsprop=RMSPROP,
    adagrad=ADAGRAD,
    adam=ADAM,
    lion=LION,
    ademamix=ADEMAMIX,
)
|
|
|
|
@_try_torch_compile
def _optimizer_precondition_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: torch.Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """Preconditioning pass: add an update-norm quantity to ``unorm_vec`` in place.

    The optimizer state update is recomputed from ``gnorm_scale * g`` without writing it
    back. The element-wise sum of the following is then added to ``unorm_vec``:
      * ADAM: square of ``m_hat / (sqrt(v_hat) + eps)`` (bias-corrected moments)
      * MOMENTUM: square of the new momentum
      * RMSPROP / ADAGRAD: square of ``g / (sqrt(new_state) + eps)``
      * LION: the new momentum (not squared)
      * ADEMAMIX: the current ``state1`` as-is

    ``p``, ``weight_decay`` and ``lr`` are accepted for signature parity but unused here.

    Raises:
        ValueError: if ``optimizer_id`` is not one of the module's optimizer ids.
    """
    g_vals = gnorm_scale * g

    if optimizer_id == ADAM:
        correction1 = 1.0 / (1.0 - beta1**step)
        correction2 = 1.0 / (1.0 - beta2**step)

        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

        # Bias-corrected first and second moments.
        s1_vals = s1_vals * correction1
        s2_vals = s2_vals * correction2

        update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
        update_norm = update_vals * update_vals

    elif optimizer_id == ADEMAMIX:
        # NOTE(review): sums state1 directly rather than a squared update; kept as-is.
        update_norm = state1

    elif optimizer_id == MOMENTUM:
        # The first step seeds the momentum buffer with the gradient itself.
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals
        update_norm = s1_vals * s1_vals

    elif optimizer_id == LION:
        # NOTE(review): not squared, unlike the other branches; kept as-is.
        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        update_norm = s1_vals

    elif optimizer_id == RMSPROP:
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals

    elif optimizer_id == ADAGRAD:
        s1_vals = state1 + g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals

    else:
        # Previously this fell through to an UnboundLocalError on `update_norm`.
        raise ValueError(f"Unsupported optimizer_id: {optimizer_id}")

    total_norm = torch.sum(update_norm)
    unorm_vec.add_(total_norm)
|
|
|
|
@_try_torch_compile
def _optimizer_update_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """Apply one 32-bit optimizer step, writing results in place into ``p`` and the state.

    Gradient and weight decay:
      * The gradient is ``gnorm_scale * g`` computed in float32.
      * For MOMENTUM, RMSPROP, ADAGRAD and LION, a positive ``weight_decay`` is added to the
        gradient as ``p * weight_decay``.
      * ADAM and ADEMAMIX instead apply decoupled decay ``p *= 1 - lr * weight_decay``.

    Update clipping:
      * When ``max_unorm > 0``, ``sqrt(unorm_vec)`` is used as the current update norm.
      * The update is scaled down so that norm does not exceed ``max_unorm * param_norm``
        (``+ eps`` for MOMENTUM/RMSPROP/ADAGRAD/LION).
      * This scale is applied to the ADAM, MOMENTUM, LION and RMSPROP updates only.

    For ADEMAMIX, ``state1[0]`` and ``state1[1]`` hold the beta1 and beta3 EMAs of the
    gradient, and ``state2`` holds the EMA of its square.

    Raises:
        ValueError: if ``optimizer_id`` is not one of the module's optimizer ids.
    """
    p_vals = p.float()
    g_vals = (gnorm_scale * g).float()
    if optimizer_id in (MOMENTUM, RMSPROP, ADAGRAD, LION) and weight_decay > 0.0:
        g_vals = g_vals + p_vals * weight_decay

    # Update-norm clipping factor (1.0 when clipping is disabled or not needed).
    update_scale = 1.0
    if max_unorm > 0.0:
        current_unorm = torch.sqrt(unorm_vec)
        if optimizer_id in (MOMENTUM, RMSPROP, ADAGRAD, LION):
            if current_unorm > max_unorm * param_norm + eps:
                update_scale = (max_unorm * param_norm + eps) / current_unorm
        else:
            if current_unorm > max_unorm * param_norm:
                update_scale = (max_unorm * param_norm) / current_unorm

    if optimizer_id == ADAM:
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals

        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)
        step_size = -lr * correction2 / correction1

        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)

        update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
        p_vals = p_vals + update_val

        state1.copy_(s1_vals)
        state2.copy_(s2_vals)

    elif optimizer_id == ADEMAMIX:
        s1_vals = state1[0]
        s3_vals = state1[1]
        s2_vals = state2

        m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
        m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
        nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals

        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)

        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)

        mixed_momentum = (m1 / correction1) + (alpha * m2)
        adaptive_term = (torch.sqrt(nu) / correction2) + eps
        p_vals = p_vals - lr * (mixed_momentum / adaptive_term)

        state1[0].copy_(m1)
        state1[1].copy_(m2)
        state2.copy_(nu)

    elif optimizer_id == MOMENTUM:
        # The first step seeds the momentum buffer with the gradient itself.
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals

        update_val = update_scale * (-lr * s1_vals)
        p_vals = p_vals + update_val

        state1.copy_(s1_vals)

    elif optimizer_id == LION:
        # The parameter step uses the sign of a beta1 interpolation; the stored state uses beta2.
        momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
        update_val = update_scale * lr * torch.sign(momentum_update)
        p_vals = p_vals - update_val

        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        state1.copy_(s1_vals)

    elif optimizer_id == RMSPROP:
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val

        state1.copy_(s1_vals)

    elif optimizer_id == ADAGRAD:
        s1_vals = state1 + g_vals * g_vals
        update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val

        state1.copy_(s1_vals)

    else:
        # Previously an unknown id silently wrote p back unchanged.
        raise ValueError(f"Unsupported optimizer_id: {optimizer_id}")

    p.copy_(p_vals)
|
|
|
|
@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """
    32-bit optimizer step implemented in PyTorch (via the torch.compile-wrapped helpers).

    When ``max_unorm > 0``, ``unorm_vec`` is zeroed and refilled by the preconditioning pass.
    LION runs the parameter update first and preconditions afterwards. Every other
    optimizer preconditions first and then updates.

    Raises:
        NotImplementedError: if ``skip_zeros`` is set.
        KeyError: if ``optimizer_name`` is not in ``name2optimizer_id``.
    """
    if skip_zeros:
        raise NotImplementedError("skip_zeros is not supported yet")

    optimizer_id = name2optimizer_id[optimizer_name]

    def run_update():
        _optimizer_update_32bit(
            g,
            p,
            state1,
            state2,
            unorm_vec,
            max_unorm,
            param_norm,
            beta1,
            beta2,
            beta3,
            alpha,
            eps,
            weight_decay,
            step,
            lr,
            gnorm_scale,
            optimizer_id,
        )

    def run_precondition():
        # `not (x > 0.0)` (rather than `x <= 0.0`) keeps the original handling of NaN.
        if not (max_unorm > 0.0):
            return
        unorm_vec.zero_()
        _optimizer_precondition_32bit(
            g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
        )

    stages = (run_update, run_precondition) if optimizer_name == "lion" else (run_precondition, run_update)
    for stage in stages:
        stage()
|
|