File size: 18,896 Bytes
dc9bb20 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 | from collections.abc import Sequence
from functools import wraps
from math import prod, sqrt
from typing import Optional
import torch
from ..._ops import register_kernel
from ..utils import CODE
def _try_torch_compile(func=None, **compile_kwargs):
    """
    Wrapper around torch.compile that falls back to the original function if compilation fails.

    Usable both bare (``@_try_torch_compile``) and with arguments
    (``@_try_torch_compile(**compile_kwargs)``); ``compile_kwargs`` are forwarded
    unchanged to ``torch.compile``.

    * If ``torch.compile`` itself raises, the undecorated function is returned.
    * Otherwise the compiled function is tried on each call. The first time it raises,
      the wrapper switches permanently to the original (eager) function and re-runs the
      call with it. A failing compilation is therefore not re-attempted on every later
      call. Note that any exception, including one raised by the function's own logic,
      triggers this switch. The eager re-run then surfaces such an exception to the caller.
    """

    def decorator(fn):
        try:
            compiled_fn = torch.compile(fn, **compile_kwargs)
        except Exception:
            return fn

        use_compiled = True

        @wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal use_compiled
            if use_compiled:
                try:
                    return compiled_fn(*args, **kwargs)
                except Exception:
                    # Stop retrying compilation; run eagerly from now on.
                    use_compiled = False
            return fn(*args, **kwargs)

        return wrapper

    return decorator if func is None else decorator(func)
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize an int32 int8-matmul result back to floating point.

    Each element ``A[i, j]`` is multiplied by ``row_stats[i] * col_stats[j] / (127 * 127)``.
    The ``1 / (127 * 127)`` factor matches two operands that were each quantized with a
    ``127 / absmax`` scale, as done in ``int8_vectorwise_quant``. ``A`` is viewed as 2D
    ``(-1, A.shape[-1])``, so the result is 2D.

    Args:
        A: int32 accumulator.
        row_stats: float32 stats, flattened to one value per row of the 2D view of ``A``.
        col_stats: float32 stats, flattened to one value per column of ``A``.
        dtype: output dtype; defaults to ``torch.float16``.
        bias: optional tensor added (with broadcasting) before the final dtype cast.

    Returns:
        The dequantized tensor in ``dtype``.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    A_calc = A.view(-1, A.shape[-1])
    row_stats = row_stats.reshape(-1).unsqueeze(-1)
    col_stats = col_stats.reshape(-1).unsqueeze(0)

    # Exact 1 / (127 * 127) == 6.2000124e-05. A hand-typed decimal is easy to get wrong
    # (6.200124e-05 is off by a relative ~1.8e-5).
    out = A_calc * (row_stats * col_stats) * (1.0 / (127.0 * 127.0))
    if bias is not None:
        out += bias

    return out.to(dtype or torch.float16)
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Int8 scaled matmul with outlier columns added back in A's precision.

    The int8 part is ``int8_scaled_mm(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)``. When
    ``outlier_cols`` is non-empty, the columns of ``A`` it selects are multiplied with the
    dequantized, transposed ``CB[:, outlier_cols]`` and added to that result via ``addmm``.

    Returns:
        ``(output, subA)``. ``subA`` holds the selected input columns, or an empty tensor
        of ``A``'s dtype/device when there are no outliers.
    """
    ops = torch.ops.bitsandbytes
    sub_weight = None

    if outlier_cols is not None and outlier_cols.numel():
        # Outlier inputs kept in original precision.
        sub_input = A[:, outlier_cols].contiguous()
        # Matching weight columns, dequantized to A's dtype and transposed for addmm.
        dequantized_cols = ops.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
        sub_weight = dequantized_cols.to(A.dtype).t()
        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
    else:
        # Placeholder so torch.compile always sees a tensor when there are no outliers.
        sub_input = torch.empty(0, device=A.device, dtype=A.dtype)

    # Int8 Matmul + Dequant + Bias
    result = ops.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)

    if sub_weight is None:
        return result, sub_input
    return result.addmm(sub_input, sub_weight), sub_input
@register_kernel("bitsandbytes::int8_scaled_mm", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Int8 matmul followed by dequantization (and optional bias).

    Composes ``int8_linear_matmul(A, B)`` with ``int8_mm_dequant``; the output dtype
    defaults to ``torch.float16`` when ``dtype`` is not given.
    """
    ops = torch.ops.bitsandbytes
    target_dtype = dtype or torch.float16
    accumulated = ops.int8_linear_matmul.default(A, B)
    return ops.int8_mm_dequant.default(accumulated, row_stats, col_stats, dtype=target_dtype, bias=bias)
@register_kernel("bitsandbytes::int8_linear_matmul", "default")
def _(A: torch.Tensor, B: torch.Tensor):
    """Functional variant: return a new int32 tensor holding ``A @ B.T``."""
    return _int8_linear_matmul_impl(A, B, out=None)
@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    """Out variant: write the int32 result of ``A @ B.T`` into ``out`` in place.

    Raises (via ``torch._check``) if ``out`` is not int32.
    """
    # Descriptive message, consistent with the other dtype checks in this module.
    torch._check(out.dtype == torch.int32, lambda: f"out must be int32, got {out.dtype}")
    _int8_linear_matmul_impl(A, B, out)
def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
    """Reference int8 matmul computing ``A @ B.T`` as int32.

    Naive implementation: both operands are upcast to float32, multiplied, and the
    product is converted to int32. If ``out`` is given, the result is copied into it
    and ``out`` is returned; otherwise the new tensor is returned.
    """
    product = A.float() @ B.float().t()
    as_int32 = product.to(torch.int32)
    return as_int32 if out is None else out.copy_(as_int32)
@register_kernel("bitsandbytes::int8_vectorwise_quant", "default")
def _(A: torch.Tensor, threshold=0.0):
    """Row-wise absmax quantization of ``A`` to int8.

    All leading dimensions of ``A`` are treated as rows and the last dimension as
    columns. Each row is scaled by ``127 / absmax(row)`` and rounded to int8. Rows whose
    absmax is 0 quantize to all zeros rather than NaN-derived values.

    If ``threshold > 0``, elements with ``|x| >= threshold`` are outliers. They are
    zeroed before the row absmax is taken. When there is more than one row, every column
    containing an outlier is zeroed in the int8 output.

    ``A`` is not modified; outliers are masked in a temporary copy.

    Returns:
        out_row: int8 tensor with the same shape as ``A``.
        row_stats: float32 per-row absmax, shape ``(rows,)``.
        outlier_cols: int64 indices of outlier columns. It is an empty tensor when
            ``threshold > 0`` but no outliers were found, and ``None`` when ``threshold <= 0``.
    """
    cols = A.shape[-1]
    rows = prod(A.shape[:-1])
    # Reduce over the last dim on a (rows, cols) view; identical to A for 2D inputs.
    A_2d = A.reshape(rows, cols)

    outlier_cols = None
    if threshold > 0.0:
        outliers = A_2d.abs() >= threshold

        if outliers.any():
            # Columns that contain at least one outlier.
            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
            # Out-of-place masking: zero the outliers ahead of quantization without
            # mutating the caller's tensor (no backup/restore needed).
            A_2d = A_2d.masked_fill(outliers, 0)
        else:
            # Needed for torch.compile support.
            outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)

    # Get absmax for each row.
    row_stats = torch.max(A_2d.abs(), dim=1).values.float()

    # Quantize row-wise to int8. A zero absmax would give scale inf and 0 * inf = NaN,
    # so such rows get a scale of 0 instead.
    scale = (127.0 / row_stats).masked_fill(row_stats == 0, 0.0)
    out_row = torch.round(A_2d * scale.unsqueeze(-1)).to(torch.int8)

    # Zero out values from outlier columns across all rows.
    if rows > 1 and outlier_cols is not None:
        out_row[:, outlier_cols] = 0

    return out_row.reshape(A.shape), row_stats, outlier_cols
@register_kernel("bitsandbytes::quantize_blockwise", "default")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Blockwise quantization of ``A`` against a lookup table ``code``.

    ``A`` is flattened and cut into blocks of ``blocksize`` elements; a trailing partial
    block gets its own absmax slot. Each block is divided by its absmax and clamped to
    [-1, 1]. Every value is then replaced by the index (``argmin`` of the absolute
    difference) of the nearest ``code`` entry.

    Returns:
        The indices cast to uint8 and shaped like ``A``, and a float32 absmax with one
        entry per (full or partial) block.
    """
    torch._check_is_size(blocksize)

    total = A.numel()
    tail = total % blocksize
    has_tail = tail > 0
    full_count = total // blocksize
    block_count = full_count + has_tail

    flat = A.reshape(total)
    absmax = torch.zeros((block_count,), device=A.device, dtype=torch.float32)

    # Whole blocks: one absmax per row of a (full_count, blocksize) matrix.
    whole = flat[: total - tail].reshape(full_count, blocksize)
    absmax[:full_count] = torch.abs(whole).max(dim=-1)[0]
    inv_scale = 1 / absmax[:full_count].view(-1, 1)
    scaled = torch.clamp(whole * inv_scale, -1, 1).reshape(-1)

    # Trailing partial block, if any, uses the last absmax slot.
    if has_tail:
        leftover = flat[total - tail :]
        absmax[-1] = torch.abs(leftover).max()
        scaled_leftover = torch.clamp(leftover * (1 / absmax[-1]), -1, 1)
        scaled = torch.cat([scaled, scaled_leftover], dim=0)

    # Nearest-entry lookup: index of the closest code value for every element.
    distance = torch.abs(scaled.unsqueeze(-1) - code.to(scaled.device))
    indices = torch.argmin(distance, dim=-1).to(torch.uint8)
    return indices.reshape(A.shape), absmax
@register_kernel("bitsandbytes::dequantize_blockwise", "default")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Inverse of blockwise quantization: ``code[A]`` scaled by each block's absmax.

    ``A`` must be uint8 indices into ``code``. Values are grouped into blocks of
    ``blocksize`` (the last block may be partial, handled by zero-padding), multiplied by
    the matching ``absmax`` entry, cast to ``dtype`` and reshaped to ``A.shape``.
    """
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    values = code[A.reshape(-1).int()]
    count = values.shape[-1]
    partial = count % blocksize

    # Pad to a whole number of blocks so a single (blocks, blocksize) view works.
    if partial != 0:
        values = torch.nn.functional.pad(values, (0, blocksize - partial), mode="constant", value=0)

    scaled = (values.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
    # Drop the padding again before restoring the input shape.
    return scaled[:count].reshape(A.shape)
@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Blockwise 4-bit (nf4/fp4) quantization of ``A``.

    ``A`` is flattened and split into blocks of ``blocksize`` elements; the last block
    may be partial. Each block is scaled into [-1, 1] by its absmax. Every value is then
    mapped to the index of the nearest entry of the ``CODE[quant_type]`` table. Two 4-bit
    indices are packed per byte, with the earlier element in the high nibble.

    Returns:
        packed: uint8 of shape ``((n + 1) // 2, 1)``, or that byte buffer reinterpreted as
            ``quant_storage`` when it is not uint8. For an odd element count the final low
            nibble is zero padding.
        absmax: float32 per-block absmax, one entry per (full or partial) block.
    """
    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()
    full_blocks = n // blocksize
    rem = n % blocksize
    blocks = full_blocks + 1 if rem else full_blocks
    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
    A_flattened = A.reshape(n)

    # Scale full blocks of the tensor to [-1, 1]
    A_full_blocks = A_flattened[: n - rem].reshape(n // blocksize, blocksize)
    absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0]
    scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

    # Scale any partial block
    if rem:
        A_rem = A_flattened[-rem:]
        absmax[-1] = torch.abs(A_rem).max()
        scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1)
        scaled = torch.cat([scaled, scaled_rem], dim=0)

    # Quantize with the lookup table
    code = CODE[quant_type].to(scaled.device).to(scaled.dtype)
    quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8)

    # Pairing needs an even count: with an odd count the [::2] and [1::2] slices would
    # differ in length. Pad with a zero index to fill the last byte's low nibble.
    if n % 2:
        quantized = torch.cat([quantized, quantized.new_zeros((1, 1))], dim=0)

    # Pack two quantized values per byte
    packed = quantized[::2] << 4 | quantized[1::2]

    if quant_storage != torch.uint8:
        # reshape(-1) rather than squeeze(): squeeze turns a single byte into a 0-d tensor.
        packed = packed.reshape(-1).view(quant_storage).unsqueeze(1)

    return packed, absmax.float()
def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Unpack and dequantize blockwise 4-bit data.

    ``A`` holds two 4-bit code indices per byte, high nibble first; non-uint8 storage is
    reinterpreted as raw bytes. Each index is looked up in ``CODE[quant_type]`` (cast to
    ``dtype``). The result is multiplied by the absmax of its block of ``blocksize``
    elements; the last block may be partial.

    The number of logical elements is ``prod(shape)``. Extra trailing nibbles in ``A``,
    such as the padding nibble of an odd-sized tensor, are ignored.

    Returns:
        A ``dtype`` tensor reshaped to ``(-1, *shape[1:])``.
    """
    # Enable non uint8 dtype
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)
    A = A.reshape(-1)

    # Logical element count comes from the target shape, not from the packed buffer.
    n = prod(shape)
    torch._check(
        2 * A.numel() >= n,
        lambda: f"packed 4-bit data holds {2 * A.numel()} values, but shape {tuple(shape)} needs {n}",
    )

    # Unpack nibbles into code indices: high nibble -> even position, low -> odd position.
    out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
    out_dq[1::2] = A & 0xF
    out_dq[::2] = A >> 4
    # code is fp32, cast to dtype to avoid the mismatch issue
    code = CODE[quant_type].to(dtype).to(A.device)
    out_dq = code[out_dq]

    # Drop padding nibbles beyond the logical element count.
    if out_dq.numel() != n:
        out_dq = torch.narrow(out_dq, 0, 0, n)

    # Apply scales
    full_blocks = n // blocksize
    rem = n % blocksize
    if rem > 0:
        out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
        out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[:full_blocks].view(-1, 1)).reshape(-1)
        out[n - rem :] = out_dq[n - rem :] * absmax[-1]
    else:
        out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

    return out.reshape(-1, *shape[1:]).to(dtype)
@register_kernel("bitsandbytes::dequantize_4bit", "default")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Validate arguments, then dequantize packed 4-bit data via ``_dequantize_4bit_impl``.

    Only ``nf4``/``fp4`` quant types and bfloat16/float16/float32 output dtypes are accepted.
    """
    supported_dtypes = (torch.bfloat16, torch.float16, torch.float32)

    torch._check_is_size(blocksize)
    torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}")
    torch._check(
        dtype in supported_dtypes,
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )

    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
@register_kernel("bitsandbytes::gemv_4bit", "default")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    """4-bit matrix-vector product: dequantize ``B`` to ``A.dtype``, then ``linear(A, B)``.

    The quant type is inferred from the lookup table: a positive second entry
    (``code[1] > 0``) selects ``"fp4"``, otherwise ``"nf4"``.
    """
    is_fp4 = code[1] > 0
    quant_type = "fp4" if is_fp4 else "nf4"

    weight = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)
    return torch.nn.functional.linear(A, weight, bias=None)
# Integer identifiers that select the update rule in the 32-bit optimizer kernels below.
MOMENTUM = 0
RMSPROP = 1
ADAGRAD = 2
ADAM = 3
# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels
LION = 4
ADEMAMIX = 5

# Maps the lowercase optimizer name received by the op to its identifier.
name2optimizer_id = dict(
    momentum=MOMENTUM,
    rmsprop=RMSPROP,
    adagrad=ADAGRAD,
    adam=ADAM,
    lion=LION,
    ademamix=ADEMAMIX,
)
@_try_torch_compile
def _optimizer_precondition_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: torch.Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
):
    """Preprocessing optimizer, computing update norm.

    Computes a per-element update-norm contribution for ``optimizer_id``, sums it, and
    adds the sum in place to ``unorm_vec``. ``state1``/``state2`` are read but not written.

    For ADAM, MOMENTUM, RMSPROP and ADAGRAD the contribution is the squared update value.
    For LION it is the new momentum itself, and for ADEMAMIX it is ``state1`` itself, as
    written. NOTE(review): confirm these two match the reference kernels.

    ``p``, ``weight_decay`` and ``lr`` are unused. They keep the argument list parallel to
    the update kernel.

    Raises:
        ValueError: if ``optimizer_id`` is not a known optimizer id.
    """
    # Upcast like _optimizer_update_32bit does. Squaring a float16/bfloat16 gradient in half
    # precision overflows for large |g| and underflows to 0 for small |g|, corrupting the norm.
    g_vals = (gnorm_scale * g).float()

    if optimizer_id == ADAM:
        correction1 = 1.0 / (1.0 - beta1**step)
        correction2 = 1.0 / (1.0 - beta2**step)
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
        s1_vals = s1_vals * correction1
        s2_vals = s2_vals * correction2
        update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
        update_norm = update_vals * update_vals
    elif optimizer_id == ADEMAMIX:
        update_norm = state1
    elif optimizer_id == MOMENTUM:
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals
        update_norm = s1_vals * s1_vals
    elif optimizer_id == LION:
        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        update_norm = s1_vals
    elif optimizer_id == RMSPROP:
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals
    elif optimizer_id == ADAGRAD:
        s1_vals = state1 + g_vals * g_vals
        update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
        update_norm = update_vals * update_vals
    else:
        raise ValueError(f"Unsupported optimizer_id: {optimizer_id}")

    unorm_vec.add_(torch.sum(update_norm))
@_try_torch_compile
def _optimizer_update_32bit(
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    optimizer_id: int,
) -> None:
    """Unified optimizer update kernel.

    Performs one optimizer step selected by ``optimizer_id`` (0=MOMENTUM, 1=RMSPROP,
    2=ADAGRAD, 3=ADAM, 4=LION, 5=ADEMAMIX) and writes the results back in place. ``p``
    is overwritten via ``p.copy_``; ``state1`` (and ``state2`` for ADAM/ADEMAMIX) are
    overwritten via ``copy_``. Arithmetic runs on float32 copies of ``p`` and of
    ``gnorm_scale * g``.

    Weight decay (when ``weight_decay > 0``):
        * MOMENTUM/RMSPROP/ADAGRAD/LION add ``p * weight_decay`` to the gradient.
        * ADAM/ADEMAMIX scale ``p`` by ``1 - lr * weight_decay`` before applying the update.

    Update-norm clipping (when ``max_unorm > 0``): ``sqrt(unorm_vec)`` is compared with
    ``max_unorm * param_norm``, plus ``eps`` for ids 0, 1, 2 and 4. When it is larger, the
    step is scaled down to that bound. ``unorm_vec`` is presumably the squared norm
    accumulated by ``_optimizer_precondition_32bit`` (confirm against caller). The ADAGRAD
    and ADEMAMIX branches do not apply this scale.

    Args:
        state1: first state buffer. ADEMAMIX indexes it as ``state1[0]`` / ``state1[1]``;
            NOTE(review): two stacked buffers inferred from the indexing.
        state2: second state buffer; used only by ADAM and ADEMAMIX.
        unorm_vec: only read, and only when ``max_unorm > 0``.
        beta3, alpha: used only by ADEMAMIX.
    """
    p_vals = p.float()
    g_vals = (gnorm_scale * g).float()
    # 1-state optimizers fold weight decay into the gradient.
    if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
        g_vals = g_vals + p_vals * weight_decay

    # Optional clipping of the update by its norm relative to the parameter norm.
    update_scale = 1.0
    if max_unorm > 0.0:
        current_unorm = torch.sqrt(unorm_vec)
        if optimizer_id in [0, 1, 2, 4]:  # 1-state optimizers
            if current_unorm > max_unorm * param_norm + eps:
                update_scale = (max_unorm * param_norm + eps) / current_unorm
        else:  # 2-state optimizers
            if current_unorm > max_unorm * param_norm:
                update_scale = (max_unorm * param_norm) / current_unorm

    if optimizer_id == 3:  # ADAM
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
        s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
        # Bias corrections folded into the step size; eps is scaled by correction2 to match.
        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)
        step_size = -lr * correction2 / correction1
        # Weight decay applied directly to the parameters (not via the gradient).
        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)
        update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
        p_vals = p_vals + update_val
        state1.copy_(s1_vals)
        state2.copy_(s2_vals)
    elif optimizer_id == 5:  # ADEMAMIX
        s1_vals = state1[0]
        s3_vals = state1[1]
        s2_vals = state2
        # EMA with beta1, EMA with beta3, and second moment with beta2.
        m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
        m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
        nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
        correction1 = 1.0 - beta1**step
        correction2 = sqrt(1.0 - beta2**step)
        if weight_decay > 0.0:
            p_vals = p_vals * (1.0 - lr * weight_decay)
        # Only m1 is bias-corrected; m2 is mixed in with weight alpha. update_scale is not used.
        mixed_momentum = (m1 / correction1) + (alpha * m2)
        adaptive_term = (torch.sqrt(nu) / correction2) + eps
        p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
        state1[0].copy_(m1)
        state1[1].copy_(m2)
        state2.copy_(nu)
    elif optimizer_id == 0:  # MOMENTUM
        # The first step seeds the buffer with the gradient itself.
        if step == 1:
            s1_vals = g_vals
        else:
            s1_vals = state1 * beta1 + g_vals
        update_val = update_scale * (-lr * s1_vals)
        p_vals = p_vals + update_val
        state1.copy_(s1_vals)
    elif optimizer_id == 4:  # LION
        # Step direction is the sign of a beta1 interpolation; the stored state uses beta2.
        momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
        update_val = update_scale * lr * torch.sign(momentum_update)
        p_vals = p_vals - update_val
        s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
        state1.copy_(s1_vals)
    elif optimizer_id == 1:  # RMSPROP
        s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
        update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val
        state1.copy_(s1_vals)
    elif optimizer_id == 2:  # ADAGRAD
        # Running sum of squared gradients; update_scale is not applied here.
        s1_vals = state1 + g_vals * g_vals
        update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
        p_vals = p_vals - update_val
        state1.copy_(s1_vals)
    # Write the float32 result back into p (copy_ casts to p's dtype).
    p.copy_(p_vals)
@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """
    32-bit optimizer implemented by PyTorch with @torch.compile

    Looks up the optimizer id from ``optimizer_name`` (raising ``KeyError`` for unknown
    names) and runs the update. When ``max_unorm > 0``, ``unorm_vec`` is zeroed and
    refilled by the precondition pass. For ``"lion"`` that pass runs after the update;
    for every other optimizer it runs before it. ``skip_zeros=True`` raises
    ``NotImplementedError``.
    """
    if skip_zeros:
        raise NotImplementedError("skip_zeros is not supported yet")

    optimizer_id = name2optimizer_id[optimizer_name]
    track_unorm = max_unorm > 0.0

    def refresh_unorm():
        # Recompute the update norm from scratch into unorm_vec.
        unorm_vec.zero_()
        _optimizer_precondition_32bit(
            g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
        )

    def apply_update():
        _optimizer_update_32bit(
            g,
            p,
            state1,
            state2,
            unorm_vec,
            max_unorm,
            param_norm,
            beta1,
            beta2,
            beta3,
            alpha,
            eps,
            weight_decay,
            step,
            lr,
            gnorm_scale,
            optimizer_id,
        )

    if optimizer_name != "lion":
        if track_unorm:
            refresh_unorm()
        apply_update()
        return

    apply_update()
    if track_unorm:
        refresh_unorm()
|