File size: 15,014 Bytes
dc9bb20 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 | from collections.abc import Sequence
from math import prod
from typing import Optional
import torch
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
_IS_TORCH_GTE_24 = True
register_fake = torch.library.register_fake
register_kernel = torch.library.register_kernel
else:
# PyTorch <= 2.3
register_fake = torch.library.impl_abstract
register_kernel = torch.library.impl
# Int8 mixed precision matmul + dequant + bias
torch.library.define(
    "bitsandbytes::int8_mixed_scaled_mm",
    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_mixed_scaled_mm")
def _(
    A: torch.Tensor,
    CA: torch.Tensor,
    CB: torch.Tensor,
    SCA: torch.Tensor,
    SCB: torch.Tensor,
    outlier_cols: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Fake (metadata-only) implementation of ``int8_mixed_scaled_mm``.

    Returns:
        ``(out, subA)``. ``out`` is an empty tensor of shape
        ``(*CA.shape[:-1], CB.shape[0])`` with ``A``'s dtype and device. ``subA`` is
        an empty 1-D int64 tensor whose length is a fresh data-dependent size.
    """
    result_shape = (*CA.shape[:-1], CB.shape[0])
    result = torch.empty(result_shape, device=A.device, dtype=A.dtype)

    # The second output's length is only known at run time, so it gets an unbacked
    # dynamic size. A separate name is used so the `outlier_cols` argument is not
    # shadowed.
    # NOTE(review): confirm the 1-D int64 layout matches what the real kernels return.
    num_outliers = torch.library.get_ctx().new_dynamic_size()
    sub_a = A.new_empty(num_outliers, dtype=torch.int64)

    return result, sub_a
# Higher level op: int8 matmul + dequant + bias
torch.library.define(
    "bitsandbytes::int8_scaled_mm",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_scaled_mm")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``int8_scaled_mm``.

    Returns an empty tensor of shape ``(*A.shape[:-1], B.shape[0])`` on ``A``'s
    device, in ``dtype`` (float16 when ``dtype`` is not given).
    """
    result_dtype = dtype or torch.float16
    result_shape = (*A.shape[:-1], B.shape[0])
    return torch.empty(result_shape, device=A.device, dtype=result_dtype)
torch.library.define(
    "bitsandbytes::int8_linear_matmul",
    "(Tensor A, Tensor B) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor):
    """Fake (metadata-only) implementation of ``int8_linear_matmul``.

    Checks that both operands are int8. Returns an empty int32 tensor of shape
    ``(*A.shape[:-1], B.shape[0])`` on ``A``'s device.
    """
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(B.dtype == torch.int8, lambda: "B must be int8")

    leading_dims = A.shape[:-1]
    return torch.empty((*leading_dims, B.shape[0]), device=A.device, dtype=torch.int32)
# More info on `out` overloads:
# https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
    "bitsandbytes::int8_linear_matmul.out",
    "(Tensor A, Tensor B, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::int8_linear_matmul.out")
def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
    """Fake (metadata-only) implementation of the ``out=`` overload of ``int8_linear_matmul``.

    Produces no value. It only validates that:
      * ``A`` and ``B`` are int8;
      * ``out`` has shape ``(*A.shape[:-1], B.shape[0])``, is on ``A``'s device, and
        is int32.
    """
    expected = (*A.shape[:-1], B.shape[0])

    # Operand dtypes are checked first, then the shape, device and dtype of `out`.
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
    torch._check(out.shape == expected, lambda: f"Expected out.shape == {expected}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
torch.library.define(
    "bitsandbytes::int8_vectorwise_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
    """Fake (metadata-only) implementation of ``int8_vectorwise_quant``.

    Returns ``(quantized, row_stats, outliers)``:
      * ``quantized`` -- an empty int8 tensor with ``A``'s shape;
      * ``row_stats`` -- an empty float32 vector with ``prod(A.shape[:-1])`` entries;
      * ``outliers`` -- ``None`` when ``threshold == 0.0``, otherwise an empty 1-D
        int64 tensor whose length is data-dependent.
    """
    quantized = torch.empty(A.shape, device=A.device, dtype=torch.int8)
    num_rows = prod(A.shape[:-1])
    row_stats = torch.empty(num_rows, device=A.device, dtype=torch.float32)

    if threshold != 0.0:
        # The number of outlier columns is only known at run time.
        num_outliers = torch.library.get_ctx().new_dynamic_size()
        return quantized, row_stats, A.new_empty(num_outliers, dtype=torch.int64)

    return quantized, row_stats, None
torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")


@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``int8_vectorwise_dequant``.

    Checks that ``A`` is int8. Returns an empty float32 tensor laid out like ``A``
    (via ``empty_like``).
    """
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    dequantized = torch.empty_like(A, dtype=torch.float32)
    return dequantized
# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
def _(A: torch.Tensor, stats: torch.Tensor):
    """Dequantize ``A`` by multiplying it by ``stats`` and then by (approximately) 1/127.

    ``stats`` is viewed as a column of shape ``(-1, 1)``, so it broadcasts along
    ``A``'s last dimension. For a 2-D ``A`` that means one scale per row.
    """
    # Dividing by 127 is done as a multiply by the reciprocal. The literal is kept
    # verbatim, and the multiply order (A * scale first) is unchanged.
    reciprocal_127 = 7.874015718698502e-3
    row_scale = stats.view(-1, 1)
    return A * row_scale * reciprocal_127
torch.library.define(
    "bitsandbytes::int8_mm_dequant",
    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_mm_dequant")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``int8_mm_dequant``.

    Checks that ``A`` (the integer matmul accumulator) is int32. Returns an empty
    tensor laid out like ``A`` in ``dtype`` (float16 when ``dtype`` is not given).
    """
    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
    result_dtype = dtype or torch.float16
    return torch.empty_like(A, dtype=result_dtype)
torch.library.define(
    "bitsandbytes::int8_double_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
    A: torch.Tensor,
    threshold=0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Fake (metadata-only) implementation of ``int8_double_quant``.

    Returns ``(row_quant, col_quant, row_stats, col_stats, outlier_cols)``:
      * two empty int8 tensors laid out like ``A``;
      * an empty float32 vector with ``prod(A.shape[:-1])`` entries;
      * an empty float32 vector with ``A.shape[-1]`` entries;
      * an empty 1-D int64 tensor whose length is data-dependent.

    NOTE(review): unlike the ``int8_vectorwise_quant`` fake, the last output is
    always a tensor here, even when ``threshold == 0.0``. Confirm this matches the
    real kernels.
    """
    row_quant = torch.empty_like(A, dtype=torch.int8)
    col_quant = torch.empty_like(A, dtype=torch.int8)

    device = A.device
    row_stats = torch.empty(prod(A.shape[:-1]), device=device, dtype=torch.float32)
    col_stats = torch.empty(A.shape[-1], device=device, dtype=torch.float32)

    num_outliers = torch.library.get_ctx().new_dynamic_size()
    outlier_cols = A.new_empty(num_outliers, dtype=torch.int64)

    return row_quant, col_quant, row_stats, col_stats, outlier_cols
torch.library.define(
    "bitsandbytes::dequantize_4bit",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_4bit")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``dequantize_4bit``.

    Checks that ``blocksize`` is a valid size. Returns an empty tensor of the
    requested ``shape`` and ``dtype`` on ``A``'s device.
    """
    torch._check_is_size(blocksize)
    target_device = A.device
    return torch.empty(shape, dtype=dtype, device=target_device)
torch.library.define(
    "bitsandbytes::dequantize_4bit.out",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_4bit.out")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Fake (metadata-only) implementation of the ``out=`` overload of ``dequantize_4bit``.

    Produces no value. It only validates that ``blocksize`` is a valid size and that
    ``out`` has the requested ``shape``, is on ``A``'s device, and has ``dtype``.
    """
    torch._check_is_size(blocksize)
    # `out.shape` is a torch.Size (a tuple subclass), but `shape` is declared as any
    # Sequence[int] / `int[]` and may arrive as a list. A tuple never compares equal
    # to a list in Python, so both sides are normalised to tuples before comparing.
    torch._check(
        tuple(out.shape) == tuple(shape),
        lambda: f"Expected out.shape == {shape}, got {out.shape}",
    )
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define(
    "bitsandbytes::quantize_4bit",
    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_4bit")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (metadata-only) implementation of ``quantize_4bit``.

    Returns ``(packed, absmax)``:
      * ``packed`` -- an empty ``(k, 1)`` tensor of dtype ``quant_storage``, where
        ``k = (A.numel() + 1) // (quant_storage.itemsize * 2)``;
      * ``absmax`` -- an empty float32 vector with ``ceil(A.numel() / blocksize)``
        entries.
    """
    torch._check_is_size(blocksize)
    numel = A.numel()
    device = A.device

    # -(a // -b) is ceiling division.
    num_blocks = -(numel // -blocksize)
    absmax = torch.empty((num_blocks,), device=device, dtype=torch.float32)

    # Presumably two 4-bit values per byte of storage.
    # NOTE(review): when itemsize > 1 this floors rather than rounds up. Confirm it
    # matches the real kernels' allocation.
    packed_rows = (numel + 1) // (quant_storage.itemsize * 2)
    packed = torch.empty((packed_rows, 1), device=device, dtype=quant_storage)
    return packed, absmax
torch.library.define(
    "bitsandbytes::dequantize_blockwise",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``dequantize_blockwise``.

    Checks that ``blocksize`` is a valid size and that ``A`` is uint8. Returns an
    empty tensor laid out like ``A`` in ``dtype``.
    """
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    result = torch.empty_like(A, dtype=dtype)
    return result
torch.library.define(
    "bitsandbytes::dequantize_blockwise.out",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::dequantize_blockwise.out")
def _(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
):
    """Fake (metadata-only) implementation of the ``out=`` overload of ``dequantize_blockwise``.

    Produces no value. It validates that:
      * ``blocksize`` is a valid size and ``A`` is uint8;
      * ``out`` has ``A``'s shape and device, and has ``dtype``.
    """
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")

    src_shape, src_device = A.shape, A.device
    torch._check(out.shape == src_shape, lambda: f"Expected out.shape == {src_shape}, got {out.shape}")
    torch._check(out.device == src_device, lambda: f"Expected out.device == {src_device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (metadata-only) implementation of ``quantize_blockwise``.

    Returns ``(quantized, absmax)``:
      * ``quantized`` -- an empty uint8 tensor laid out like ``A``;
      * ``absmax`` -- an empty float32 vector with ``ceil(A.numel() / blocksize)``
        entries.
    """
    torch._check_is_size(blocksize)

    element_count = A.numel()
    # -(a // -b) is ceiling division.
    block_count = -(element_count // -blocksize)
    absmax = torch.empty((block_count,), device=A.device, dtype=torch.float32)

    quantized = torch.empty_like(A, dtype=torch.uint8)
    return quantized, absmax
torch.library.define(
    "bitsandbytes::gemv_4bit",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)


@register_fake("bitsandbytes::gemv_4bit")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    """Fake (metadata-only) implementation of ``gemv_4bit``.

    Validates that:
      * ``blocksize`` is a valid size;
      * ``A`` holds a single vector (``A.numel() == A.size(-1)``);
      * ``A`` is float16/bfloat16/float32;
      * ``B``'s storage dtype is uint8/bfloat16/float16/float32.

    Returns an empty tensor of shape ``(*A.shape[:-1], shapeB[0])`` with ``A``'s
    dtype and device.
    """
    a_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    b_dtypes = [torch.uint8, torch.bfloat16, torch.float16, torch.float32]

    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in a_dtypes,
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in b_dtypes,
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )

    result_shape = (*A.shape[:-1], shapeB[0])
    return torch.empty(result_shape, device=A.device, dtype=A.dtype)
torch.library.define(
    "bitsandbytes::gemv_4bit.out",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
)


@register_fake("bitsandbytes::gemv_4bit.out")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    """Fake (metadata-only) implementation of the ``out=`` overload of ``gemv_4bit``.

    Produces no value. It applies the same input checks as ``gemv_4bit``, then
    checks that ``out`` has shape ``(*A.shape[:-1], shapeB[0])`` and ``A``'s device
    and dtype.
    """
    a_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    b_dtypes = [torch.uint8, torch.bfloat16, torch.float16, torch.float32]

    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in a_dtypes,
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in b_dtypes,
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )

    # The expected shape is computed only after the input checks above, as before.
    expected_shape = (*A.shape[:-1], shapeB[0])
    torch._check(
        out.shape == expected_shape,
        lambda: f"Expected out.shape == {expected_shape}, got {out.shape}",
    )
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
torch.library.define(
    "bitsandbytes::optimizer_update_32bit",
    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_32bit")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    unorm_vec: Optional[torch.Tensor],
    max_unorm: float,
    param_norm: float,
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    step: int,
    lr: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    """Fake (metadata-only) implementation of ``optimizer_update_32bit``.

    The op mutates its tensor arguments in place and returns nothing. This fake only
    validates that:
      * ``g`` and ``p`` have the same number of elements;
      * ``g`` is float16, bfloat16 or float32;
      * ``p`` has the same dtype as ``g``.
    """
    supported = [torch.float16, torch.bfloat16, torch.float32]
    grad_numel, param_numel = g.numel(), p.numel()

    torch._check(
        grad_numel == param_numel,
        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    )
    torch._check(
        g.dtype in supported,
        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    )
    torch._check(
        g.dtype == p.dtype,
        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    )
torch.library.define(
    "bitsandbytes::optimizer_update_8bit_blockwise",
    "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
)


@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
def _(
    optimizer_name: str,
    g: torch.Tensor,
    p: torch.Tensor,
    state1: torch.Tensor,
    state2: Optional[torch.Tensor],
    beta1: float,
    beta2: float,
    beta3: float,
    alpha: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: torch.Tensor,
    qmap2: Optional[torch.Tensor],
    absmax1: torch.Tensor,
    absmax2: Optional[torch.Tensor],
    weight_decay: float,
    gnorm_scale: float,
    skip_zeros=False,
) -> None:
    """Fake (metadata-only) implementation of ``optimizer_update_8bit_blockwise``.

    The op mutates its tensor arguments in place and returns nothing. This fake only
    validates that:
      * ``g`` and ``p`` have the same number of elements;
      * ``g`` is float16/bfloat16/float32, and ``p`` has the same dtype;
      * ``state1`` is uint8, and ``qmap1``/``absmax1`` are float32.

    When ``state2`` is given, it must be uint8, and ``qmap2``/``absmax2`` must be
    provided and be float32.
    """
    torch._check(
        g.numel() == p.numel(),
        lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
    )
    compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    torch._check(
        g.dtype in compute_dtypes,
        lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
    )
    torch._check(
        g.dtype == p.dtype,
        lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
    )
    torch._check(
        state1.dtype == torch.uint8,
        lambda: f"state1 must be uint8, got {state1.dtype}",
    )
    torch._check(
        qmap1.dtype == absmax1.dtype == torch.float32,
        lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
    )
    if state2 is not None:
        torch._check(
            state2.dtype == torch.uint8,
            lambda: f"state2 must be uint8, got {state2.dtype}",
        )
        # qmap2/absmax2 are Optional in the schema. Without this check, a missing
        # one would surface as an opaque AttributeError on `.dtype` below.
        torch._check(
            qmap2 is not None and absmax2 is not None,
            lambda: "qmap2 and absmax2 must be provided when state2 is given",
        )
        torch._check(
            qmap2.dtype == absmax2.dtype == torch.float32,
            lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
        )
|