File size: 14,025 Bytes
dc9bb20 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 | from dataclasses import dataclass
from math import prod
from typing import Optional
import warnings
from warnings import warn
import torch
import bitsandbytes.functional as F
# The inverse transformations for the colTuring and colAmpere formats were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler:
    """Pool outlier dimensions across layers.

    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.

    The class is a singleton: use ``GlobalOutlierPooler.get_instance()``.
    Calling the constructor directly raises ``RuntimeError``.
    """

    # The single shared instance, created lazily by get_instance().
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Reset the pool: no outlier indices and no reference feature dimension."""
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating and initializing it on first use."""
        if cls._instance is None:
            # Bypass __init__, which deliberately raises.
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        """Add the indices in ``outlier_idx`` to the pool.

        The first ``feature_dim`` ever seen becomes the reference dimension;
        calls made with any other ``feature_dim`` are ignored.

        Args:
            outlier_idx: object exposing ``tolist()`` that yields the indices.
            feature_dim: feature dimension the indices refer to.
        """
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        """Return the pooled indices as a 1-D ``torch.int64`` tensor.

        The order follows the iteration order of the underlying set.
        """
        # Build the int64 tensor directly. Going through the float32
        # torch.Tensor constructor would round indices above 2**24.
        return torch.tensor(list(self.outliers), dtype=torch.int64)
# Module-level alias for torch.compiler.is_compiling. MatMul8bitLt.forward calls
# it so that its float16-cast warning is skipped whenever it returns True.
_is_compiling = torch.compiler.is_compiling
@dataclass
class MatmulLtState:
    """Mutable state object handed to the 8-bit matmul autograd functions.

    Only the annotated names below are dataclass fields, and therefore
    accepted by the generated ``__init__``. The un-annotated names
    (``has_accumulated_gradients``, ``threshold``, ``is_training``,
    ``has_fp16_weights``, ``use_pool``, ``formatB``) are plain class
    attributes that act as defaults; override them by assigning on the
    instance.
    """

    _tile_indices: Optional[torch.Tensor] = None  # TODO: remove

    force_no_igemmlt: bool = False

    # CB / SCB are written by the matmul functions from the outputs of
    # F.int8_vectorwise_quant on the weight.
    CB: Optional[torch.Tensor] = None
    CxB: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
    SB: Optional[torch.Tensor] = None
    SCB: Optional[torch.Tensor] = None

    CxBt: Optional[torch.Tensor] = None  # TODO: Deprecate/remove
    SBt: Optional[torch.Tensor] = None
    CBt: Optional[torch.Tensor] = None

    subB: Optional[torch.Tensor] = None

    outlier_pool: Optional[GlobalOutlierPooler] = None
    has_accumulated_gradients = False
    # A value > 0.0 makes MatMul8bitLt take its mixed (outlier) matmul path.
    threshold = 0.0
    # Set by MatMul8bitLt.forward to the outlier columns when threshold > 0.0.
    idx: Optional[torch.Tensor] = None
    is_training = True
    has_fp16_weights = True
    use_pool = False
    formatB = "row"  # TODO: Deprecate/remove

    def reset_grads(self):
        """Clear every cached weight tensor and scale by setting it to ``None``."""
        for cached in ("CB", "CxB", "SB", "SCB", "CxBt", "SBt", "CBt"):
            setattr(self, cached, None)

    @property
    def tile_indices(self):
        """Removed feature; any access raises ``ValueError``."""
        raise ValueError("tile_indices is no longer supported.")
class MatMul8bitLt(torch.autograd.Function):
    """Autograd function for the int8 matmul, with an optional mixed path for outlier columns."""

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        state: Optional[MatmulLtState] = None,
    ):
        """Quantize ``A`` (and ``B`` when needed) to int8 and multiply them.

        Args:
            ctx: autograd context; receives everything ``backward`` reads.
            A: input activations; a 3-D input is flattened to 2-D for the kernels.
            B: weight; quantized into ``state.CB`` / ``state.SCB`` when required.
            out: not used by this implementation.
            bias: optional bias forwarded to the matmul op.
            state: quantization state; a fresh ``MatmulLtState`` is used when falsy.

        Returns:
            The matmul result, reshaped back to 3-D when ``A`` was 3-D.
        """
        if not state:
            state = MatmulLtState()

        # Empty input: follow PyTorch's default behaviour and return an empty tensor.
        ctx.is_empty = prod(A.shape) == 0
        if ctx.is_empty:
            ctx.A, ctx.B, ctx.bias = A, B, bias
            trailing = B.shape[1:] if A.shape[-1] == B.shape[0] else B.shape[:1]
            return torch.empty(A.shape[:-1] + trailing, dtype=A.dtype, device=A.device)

        input_shape = A.shape

        # Quantization works on float16, so tell the user about the implicit cast.
        if A.dtype != torch.float16 and not _is_compiling():
            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

        if len(input_shape) == 3:
            A = A.reshape(-1, input_shape[-1])

        # Quantize A. As a side effect, outliers are suppressed in CA/CAt.
        if ctx.needs_input_grad[1]:
            # Slower path: the transposed quantization is needed for grad_B.
            CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
        else:
            # Fast path: the weight gradient is not requested.
            CAt = SCAt = None
            CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)

        if state.has_fp16_weights or state.CB is None:
            weight_has_grad = getattr(B, "grad", None) is not None
            if not B.is_contiguous() and B.shape[0] == B.stride(1):
                # B looks like a transposed view; materialize it.
                B = B.contiguous()

            if (state.is_training and not weight_has_grad) or state.CB is None or state.SCB is None:
                state.reset_grads()

                # Quantize B.
                state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

        if state.threshold > 0.0:
            # Sparse decomposition: mixed int8 matmul + dequant + bias.
            state.idx = outlier_cols
            output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
                A,
                CA,
                state.CB,
                SCA,
                state.SCB,
                outlier_cols,
                bias,
            )
        else:
            # Plain int8 matmul + dequant + bias.
            subA = None
            output = torch.ops.bitsandbytes.int8_scaled_mm.default(
                CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
            )

        # Stash what backward needs.
        ctx.state = state
        ctx.grad_shape = input_shape
        ctx.dtype_A = A.dtype
        ctx.dtype_bias = bias.dtype if bias is not None else None

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA, A)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        output_shape = (*input_shape[:-1], state.CB.shape[0])
        return output.reshape(output_shape) if len(input_shape) == 3 else output

    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
        """Compute gradients for ``A``, ``B`` and ``bias``.

        Returns:
            ``(grad_A, grad_B, None, grad_bias, None)``, matching the forward
            inputs ``(A, B, out, bias, state)``. Gradients that were not
            requested are ``None``.

        Raises:
            Exception: if ``grad_A`` is requested but ``state.CB`` is ``None``.
        """
        if ctx.is_empty:
            zero_bias = torch.zeros_like(ctx.bias) if ctx.bias is not None else None
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, zero_bias, None

        want_grad_A, want_grad_B, _, want_grad_bias, _ = ctx.needs_input_grad
        CAt, subA, _A = ctx.tensors
        SCAt, idx = ctx.tensor_states
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if want_grad_bias:
            # Must happen before grad_output is reshaped or cast below.
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        if want_grad_B:
            Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))

            grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
                Cgrad.t().contiguous(),
                CAt.t(),
                SCgradt,
                SCAt,
                dtype=torch.float16,
            )

            # Add the contribution of the outlier columns kept in subA.
            if state.threshold > 0.0 and subA is not None and subA.numel() > 0:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if want_grad_A:
            if state.CB is None:
                raise Exception("State must contain CB matrix for backward")
            # Dequantize a copy of the int8 weight: CB * SCB / 127, SCB broadcast along dim 1.
            weight = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
            grad_A = torch.matmul(grad_output.to(ctx.dtype_A), weight).view(ctx.grad_shape)

        return grad_A, grad_B, None, grad_bias, None
class MatMul8bitFp(torch.autograd.Function):
    """Int8-weight matmul that dequantizes the weight and runs a floating-point linear.

    For Intel CPU and XPU, MatMul8bitFp is much faster (~3x) than MatMul8bitLt in
    finetuning, because MatMul8bitLt has more mechanisms in computing gradients.
    There is no fast kernel for 8-bit quant/dequant on CPU/XPU, so that path is
    very slow; dequant + matmul is used to finetune with good performance.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None):
        """Dequantize the int8 weight held in ``state`` and apply a linear layer.

        Args:
            ctx: autograd context; receives everything ``backward`` reads.
            A: input activations.
            B: weight; quantized into ``state.CB`` / ``state.SCB`` when required.
            out: not used by this implementation.
            bias: optional bias passed to ``torch.nn.functional.linear``.
            state: quantization state; a fresh ``MatmulLtState`` is created when
                ``None``.

        Returns:
            ``torch.nn.functional.linear(A, CB, bias)`` with ``CB`` the
            dequantized weight.
        """
        # The default used to be the MatmulLtState class itself, which broke on
        # state.reset_grads() (unbound method) whenever the argument was omitted.
        if state is None:
            state = MatmulLtState()

        if state.has_fp16_weights or state.CB is None:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
                state.reset_grads()
                state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

        # Dequantize: CB * SCB / 127, with SCB broadcast along dim 1.
        CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
        output = torch.nn.functional.linear(A, CB, bias)

        ctx.state = state
        ctx.dtype_A = A.dtype
        ctx.grad_shape = A.shape
        ctx.A = A
        ctx.dtype_bias = None if bias is None else bias.dtype
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``A``, ``B`` and ``bias``.

        Returns:
            ``(grad_A, grad_B, None, grad_bias, None)``, matching the forward
            inputs ``(A, B, out, bias, state)``. Gradients that were not
            requested are ``None``.

        Raises:
            Exception: if ``grad_A`` is requested but ``state.CB`` is ``None``.
        """
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        A = ctx.A
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        if req_gradB:
            # grad_output was flattened to 2-D above, so A must be flattened the
            # same way. Tensor.t() rejects tensors with more than 2 dimensions.
            if len(A.shape) == 3:
                A = A.reshape(-1, A.shape[-1])
            grad_B = torch.matmul(A.t(), grad_output).t()

        if req_gradA:
            if state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
            else:
                raise Exception("State must contain CB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None
class MatMul4Bit(torch.autograd.Function):
    """Autograd function that dequantizes a 4-bit weight and applies a linear layer.

    Only the gradients for ``A`` and ``bias`` are computed; ``grad_B`` is always ``None``.
    """

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
        """Dequantize ``B`` with ``quant_state`` and compute the linear output.

        Args:
            ctx: autograd context; receives everything ``backward`` reads.
            A: input activations.
            B: quantized weight, passed to ``F.dequantize_4bit``.
            out: not used by this implementation.
            bias: optional bias passed to ``torch.nn.functional.linear``.
            quant_state: quantization state; its ``shape`` is read for empty inputs.

        Returns:
            The linear output, or an empty tensor when ``A`` has no elements.
        """
        # Empty input: follow PyTorch's default behaviour and return an empty tensor.
        ctx.is_empty = prod(A.shape) == 0
        if ctx.is_empty:
            ctx.A, ctx.B, ctx.bias = A, B, bias
            B_shape = quant_state.shape
            trailing = B_shape[1:] if A.shape[-1] == B_shape[0] else B_shape[:1]
            return torch.empty(A.shape[:-1] + trailing, dtype=A.dtype, device=A.device)

        # Dequantize the weight, then run the matmul.
        weight = F.dequantize_4bit(B, quant_state).to(A.dtype).t()
        output = torch.nn.functional.linear(A, weight, bias)

        # Stash what backward needs.
        ctx.state = quant_state
        ctx.dtype_A = A.dtype
        ctx.dtype_B = B.dtype
        ctx.dtype_bias = bias.dtype if bias is not None else None
        ctx.tensors = (None, B) if any(ctx.needs_input_grad[:2]) else (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``A`` and ``bias``.

        Returns:
            ``(grad_A, grad_B, None, grad_bias, None)``, matching the forward
            inputs ``(A, B, out, bias, quant_state)``. ``grad_B`` is ``None``
            except for empty inputs, where zeros shaped like ``B`` are returned.
        """
        if ctx.is_empty:
            zero_bias = torch.zeros_like(ctx.bias) if ctx.bias is not None else None
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, zero_bias, None

        want_grad_A, _, _, want_grad_bias, _ = ctx.needs_input_grad
        _, B = ctx.tensors

        grad_A = grad_B = grad_bias = None

        if want_grad_bias:
            # Must be computed before anything changes grad_output's dtype.
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)

        if want_grad_A:
            weight_t = F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t()
            grad_A = torch.matmul(grad_output, weight_t)

        return grad_A, grad_B, None, grad_bias, None
def matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    state: Optional[MatmulLtState] = None,
    threshold=0.0,
    bias: Optional[torch.Tensor] = None,
):
    """Run the 8-bit matmul, choosing the implementation from device and training mode.

    Args:
        A: input activations.
        B: weight.
        out: forwarded unchanged to the selected autograd function.
        state: quantization state; a fresh ``MatmulLtState`` is used when falsy.
        threshold: when > 0.0, stored on ``state.threshold`` before dispatch.
        bias: optional bias.

    Returns:
        The result of ``MatMul8bitFp.apply`` when ``state.is_training`` is true and
        ``A`` lives on ``cpu`` or ``xpu``; otherwise that of ``MatMul8bitLt.apply``.
    """
    if not state:
        state = MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold

    # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
    use_fp_path = state.is_training and A.device.type in ("cpu", "xpu")
    impl = MatMul8bitFp if use_fp_path else MatMul8bitLt
    return impl.apply(A, B, out, bias, state)
def matmul_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    quant_state: F.QuantState,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
):
    """Multiply ``A`` by a 4-bit quantized weight ``B``.

    Uses ``F.gemv_4bit`` when the fast path applies, and ``MatMul4Bit``
    (dequantize + linear) otherwise.

    Args:
        A: input activations.
        B: quantized weight.
        quant_state: quantization state for ``B``; must not be ``None``. On CPU
            its ``dtype`` attribute is overwritten with ``A.dtype``.
        out: forwarded to ``F.gemv_4bit`` / ``MatMul4Bit.apply``.
        bias: optional bias; added in place to the ``gemv_4bit`` result.

    Returns:
        The matmul result.
    """
    assert quant_state is not None

    if A.device.type == "cpu":
        # Change dtype to input dtype on CPU
        quant_state.dtype = A.dtype

        if not getattr(quant_state, "packing_format_for_cpu", False):
            return MatMul4Bit.apply(A, B, out, bias, quant_state)

        out = F.gemv_4bit(A, B, out, state=quant_state)
        if bias is not None:
            out += bias
        return out

    # The gemv kernel is only tried for a single input vector that does not
    # require grad and is not on an "hpu" device.
    if A.numel() != A.shape[-1] or A.requires_grad or A.device.type == "hpu":
        return MatMul4Bit.apply(A, B, out, bias, quant_state)

    if A.shape[-1] % quant_state.blocksize != 0:
        warn(
            f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
        )
        return MatMul4Bit.apply(A, B, out, bias, quant_state)

    out = F.gemv_4bit(A, B.t(), out, state=quant_state)
    if bias is not None:
        out += bias
    return out
|