build-tools / bitsandbytes /autograd /_functions.py
salmankhanpm's picture
Add files using upload-large-folder tool
dc9bb20 verified
from dataclasses import dataclass
from math import prod
from typing import Optional
import warnings
from warnings import warn
import torch
import bitsandbytes.functional as F
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler:
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.outliers = set()
self.model_dim = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
if self.model_dim is None:
self.model_dim = feature_dim
if feature_dim != self.model_dim:
return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
_is_compiling = torch.compiler.is_compiling
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None # TODO: remove
force_no_igemmlt: bool = False
CB: Optional[torch.Tensor] = None
CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SB: Optional[torch.Tensor] = None
SCB: Optional[torch.Tensor] = None
CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SBt: Optional[torch.Tensor] = None
CBt: Optional[torch.Tensor] = None
subB: Optional[torch.Tensor] = None
outlier_pool: Optional[GlobalOutlierPooler] = None
has_accumulated_gradients = False
threshold = 0.0
idx: Optional[torch.Tensor] = None
is_training = True
has_fp16_weights = True
use_pool = False
formatB = "row" # TODO: Deprecate/remove
def reset_grads(self):
self.CB = None
self.CxB = None
self.SB = None
self.SCB = None
self.CxBt = None
self.SBt = None
self.CBt = None
@property
def tile_indices(self):
raise ValueError("tile_indices is no longer supported.")
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
):
state = state or MatmulLtState()
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
input_shape = A.shape
# Cast A to fp16
if A.dtype != torch.float16 and not _is_compiling():
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if ctx.needs_input_grad[1]:
# Slower path
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
else:
# Fast path
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
CAt = SCAt = None
has_grad = False
if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads()
# 2. Quantize B
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
# Handle sparse decomposition
if state.threshold > 0.0:
state.idx = outlier_cols
# Mixed Int8 Matmul + Dequant + Bias
output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
A,
CA,
state.CB,
SCA,
state.SCB,
outlier_cols,
bias,
)
else:
# Int8 Matmul + Dequant + Bias
output = torch.ops.bitsandbytes.int8_scaled_mm.default(
CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
)
subA = None
# 5. Save state
ctx.state = state
ctx.grad_shape = input_shape
ctx.dtype_A = A.dtype
ctx.dtype_bias = None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
output_shape = (*input_shape[:-1], state.CB.shape[0])
if len(input_shape) == 3:
return output.reshape(output_shape)
return output
@staticmethod
def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, _A = ctx.tensors
SCAt, idx = ctx.tensor_states
state: MatmulLtState = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
if req_gradB:
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
Cgrad.t().contiguous(),
CAt.t(),
SCgradt,
SCAt,
dtype=torch.float16,
)
if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
else:
raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
class MatMul8bitFp(torch.autograd.Function):
# For Intel CPU and XPU MatMul8bitFp is much faster (~3x) than MatMul8bitLt in finetune.
# Because the MatMul8bitLt has more mechanisms in computing grad.
# We don't have fast kernel for quant/dequant 8bit in CPU/XPU, so it's very slow.
# We'd like to use dequant + matmul to run finetune with good performance.
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads()
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
B = state.CB
CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
ctx.A = A
ctx.dtype_bias = None if bias is None else bias.dtype
return output
@staticmethod
def backward(ctx, grad_output):
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
A = ctx.A
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
if req_gradB:
grad_B = torch.matmul(A.t(), grad_output).t()
if req_gradA:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
else:
raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
ctx.is_empty = True
ctx.A = A
ctx.B = B
ctx.bias = bias
B_shape = quant_state.shape
if A.shape[-1] == B_shape[0]:
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
# 1. Dequantize
# 2. MatmulnN
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
# 3. Save state
ctx.state = quant_state
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (None, B)
else:
ctx.tensors = (None, None)
return output
@staticmethod
def backward(ctx, grad_output):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
_, B = ctx.tensors
grad_A, grad_B, grad_bias = None, None, None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
# not supported by PyTorch. TODO: create work-around
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
return grad_A, grad_B, None, grad_bias, None
def matmul(
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias: Optional[torch.Tensor] = None,
):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)
def matmul_4bit(
A: torch.Tensor,
B: torch.Tensor,
quant_state: F.QuantState,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
# Change dtype to input dtype on CPU
if A.device.type == "cpu":
quant_state.dtype = A.dtype
if getattr(quant_state, "packing_format_for_cpu", False):
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
)
return MatMul4Bit.apply(A, B, out, bias, quant_state)
else:
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)