from typing import Optional, Tuple
import torch
from ._ops import ops
# Quant type constants (match bitsandbytes DataType_t)
FP4 = 1  # 4-bit floating-point codebook
NF4 = 2  # 4-bit NormalFloat codebook
def quantize_4bit(
    input: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor blockwise to 4 bits with an NF4 or FP4 codebook.

    Args:
        input: Source tensor on the MPS device (float16, bfloat16, or float32).
        blocksize: Elements per quantization block (64 or 128).
        quant_type: Codebook selector — FP4 (1) or NF4 (2).

    Returns:
        A ``(packed, absmax)`` pair where ``packed`` is a uint8 tensor of
        packed 4-bit values (two values per byte, length numel/2) and
        ``absmax`` holds the float32 per-block max absolute values.
    """
    result = ops.bnb_quantize_4bit(input, blocksize, quant_type)
    return result
def dequantize_4bit(
    packed: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
    numel: int = -1,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Reconstruct a tensor from blockwise 4-bit NF4/FP4 storage.

    Args:
        packed: uint8 tensor of packed 4-bit values (two per byte).
        absmax: float32 per-block max absolute values.
        blocksize: Elements per quantization block (64 or 128).
        quant_type: Codebook selector — FP4 (1) or NF4 (2).
        numel: Element count of the original tensor; any negative value
            means "infer as ``packed.numel() * 2``".
        output_dtype: Scalar type of the returned tensor.

    Returns:
        The dequantized tensor.
    """
    # A negative numel is the "not provided" sentinel: each uint8 byte
    # holds two 4-bit values, so the default element count is twice the
    # packed length.
    total = packed.numel() * 2 if numel < 0 else numel
    return ops.bnb_dequantize_4bit(
        packed, absmax, blocksize, quant_type, total, output_dtype
    )
def gemv_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Matrix-vector product against blockwise 4-bit quantized weights.

    Evaluates ``y = dequant(W) @ x`` in one fused kernel, where W is
    stored in packed NF4/FP4 form.

    Args:
        x: Input vector of shape [..., K] on the MPS device.
        w: Packed uint8 weights of shape [N, K/2] on the MPS device.
        absmax: float32 per-block scales of shape [N, ceil(K/blocksize)].
        output_features: Output feature count N.
        blocksize: Quantization block size (64 or 128).
        quant_type: Codebook selector — FP4 (1) or NF4 (2).

    Returns:
        Result tensor of shape [..., N].
    """
    y = ops.bnb_gemv_4bit(x, w, absmax, blocksize, quant_type, output_features)
    return y
def gemm_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Matrix-matrix product against transposed blockwise 4-bit weights.

    Evaluates ``Y = X @ dequant(W).T`` in one fused kernel, where W is
    stored in packed NF4/FP4 form.

    Args:
        x: Input matrix of shape [..., M, K] on the MPS device.
        w: Packed uint8 weights of shape [N, K/2] on the MPS device.
        absmax: float32 per-block scales of shape [N, ceil(K/blocksize)].
        output_features: Output feature count N.
        blocksize: Quantization block size (64 or 128).
        quant_type: Codebook selector — FP4 (1) or NF4 (2).

    Returns:
        Result tensor of shape [..., M, N].
    """
    y = ops.bnb_gemm_4bit(x, w, absmax, blocksize, quant_type, output_features)
    return y
def linear_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Apply a 4-bit quantized linear layer, picking GEMV or GEMM.

    A 1-D input, or one whose second-to-last dimension is 1, routes to
    the fused matrix-vector kernel; everything else goes through the
    matrix-matrix kernel.

    Args:
        x: Input tensor on the MPS device.
        w: Packed uint8 weights of shape [N, K/2].
        absmax: float32 scales of shape [N, ceil(K/blocksize)].
        output_features: Output feature count N.
        blocksize: Quantization block size (64 or 128).
        quant_type: Codebook selector — FP4 (1) or NF4 (2).
        bias: Optional bias vector of shape [N].

    Returns:
        The layer output tensor.
    """
    is_vector = x.dim() == 1
    use_gemv = is_vector or (x.dim() >= 2 and x.size(-2) == 1)

    if not use_gemv:
        out = gemm_4bit(x, w, absmax, output_features, blocksize, quant_type)
    else:
        # Collapse the input down to a plain vector for the GEMV kernel.
        flat = x.view(x.size(-1)) if is_vector else x.squeeze(-2)
        out = gemv_4bit(flat, w, absmax, output_features, blocksize, quant_type)
        # Restore the dimension shape the caller expects.
        if is_vector:
            out = out.squeeze(0)
        elif x.dim() >= 2:
            out = out.unsqueeze(-2)

    if bias is not None:
        out = out + bias
    return out
# Public API of this module.
__all__ = [
    "quantize_4bit",
    "dequantize_4bit",
    "gemv_4bit",
    "gemm_4bit",
    "linear_4bit",
]