# bitsandbytes-mps

Metal (MPS) kernels for bitsandbytes 4-bit quantization on Apple Silicon.

Provides NF4 and FP4 blockwise quantization, dequantization, and **fused GEMV/GEMM** operations for efficient inference with 4-bit quantized models on macOS.

## Operations

| Operation | Description |
|-----------|-------------|
| `quantize_4bit` | Blockwise 4-bit quantization (NF4/FP4) with per-block absmax |
| `dequantize_4bit` | Blockwise 4-bit dequantization using codebook lookup |
| `gemv_4bit` | Fused dequantize + matrix-vector multiply (batch_size=1 inference) |
| `gemm_4bit` | Fused dequantize + matrix-matrix multiply (larger batch inference) |
| `linear_4bit` | Auto-selecting linear layer (GEMV for vectors, GEMM for matrices) |
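The fused rows compute `Y = X @ W.T`, where each weight of `W` is rebuilt on the fly from its 4-bit index and its block's `absmax` instead of being materialized first. The pure-Python sketch below illustrates those semantics and is meant as a correctness reference, not a model of the Metal kernels. It makes several assumptions: the weight is passed as an `[N, K]` grid of already-unpacked indices (byte packing is ignored), `absmax` is `[N, K / blocksize]`, and `K` is a multiple of `blocksize`. The names `ref_gemm_4bit`, `ref_gemv_4bit`, and `ref_linear_4bit` are hypothetical and are not the package's API. The real `linear_4bit` heuristic may also differ from the simple vector-vs-batch rule assumed here.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def ref_gemm_4bit(
    X: Sequence[Sequence[float]],       # [M, K] activations
    q: Sequence[Sequence[int]],         # [N, K] unpacked 4-bit indices (0..15)
    absmax: Sequence[Sequence[float]],  # [N, K // blocksize] per-block scales
    codebook: Sequence[float],          # 16-entry table (NF4 or FP4)
    blocksize: int,
) -> Matrix:
    """Hypothetical reference for fused dequantize + GEMM: Y = X @ dequant(W).T."""
    N, K = len(q), len(q[0])
    assert K % blocksize == 0, "sketch assumes blocks never straddle rows"
    Y = [[0.0] * N for _ in X]
    for n in range(N):
        # Walk each weight row one quantization block at a time. The tile is
        # dequantized once and reused for every row of X; W is never stored whole.
        for b in range(K // blocksize):
            start = b * blocksize
            scale = absmax[n][b]
            tile = [codebook[idx] * scale for idx in q[n][start:start + blocksize]]
            for m, x in enumerate(X):
                Y[m][n] += sum(w * xv for w, xv in zip(tile, x[start:start + blocksize]))
    return Y


def ref_gemv_4bit(x, q, absmax, codebook, blocksize) -> List[float]:
    """GEMV is the M = 1 special case of the GEMM reference."""
    return ref_gemm_4bit([list(x)], q, absmax, codebook, blocksize)[0]


def ref_linear_4bit(x, q, absmax, codebook, blocksize):
    """Assumed dispatch rule: a single vector goes to GEMV, a batch goes to GEMM."""
    if x and isinstance(x[0], (int, float)):
        return ref_gemv_4bit(x, q, absmax, codebook, blocksize)
    return ref_gemm_4bit(x, q, absmax, codebook, blocksize)
```

A reference like this can be compared against the output of the fused ops on small random inputs.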

## Quantization Format

Uses the bitsandbytes blockwise quantization scheme:
- **Packing**: 2 values per byte (high nibble = first element, low nibble = second)
- **Scaling**: One `absmax` (float32) per block of `blocksize` elements
- **Codebook**: NF4 (16 values optimized for normal distributions) or FP4 (sign-magnitude floating point)
- **Dequantization**: `value = codebook[4bit_index] * absmax`
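The format rules above can be illustrated with a small pure-Python dequantizer. This is a sketch that works on plain `bytes` and lists, not the Metal implementation. The NF4 table is the standard one published with bitsandbytes/QLoRA and is assumed to match what these kernels use. `unpack_nibbles` and `dequantize_4bit_ref` are hypothetical helper names.

```python
from typing import List, Sequence

# Standard NF4 table from bitsandbytes/QLoRA (assumed to match these kernels).
NF4_CODEBOOK = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def unpack_nibbles(packed: bytes) -> List[int]:
    """Split each byte into two 4-bit indices, high nibble first."""
    out: List[int] = []
    for byte in packed:
        out.append(byte >> 4)    # first element lives in the high nibble
        out.append(byte & 0x0F)  # second element lives in the low nibble
    return out


def dequantize_4bit_ref(packed: bytes, absmax: Sequence[float], blocksize: int,
                        numel: int, codebook: Sequence[float] = NF4_CODEBOOK) -> List[float]:
    """value = codebook[4bit_index] * absmax[block], for the first numel elements."""
    indices = unpack_nibbles(packed)[:numel]  # numel drops a trailing pad nibble
    return [codebook[idx] * absmax[i // blocksize] for i, idx in enumerate(indices)]
```

For example, `dequantize_4bit_ref(bytes([0xF0, 0x7F]), [2.0, 0.5], blocksize=2, numel=4)` unpacks the indices `[15, 0, 7, 15]` and scales the first pair by `2.0` and the second pair by `0.5`.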

## Usage

```python
import torch
from bitsandbytes_mps import quantize_4bit, dequantize_4bit, gemv_4bit, gemm_4bit, NF4

# Quantize a weight matrix
weight = torch.randn(4096, 4096, dtype=torch.float16, device="mps")
packed, absmax = quantize_4bit(weight.flatten(), blocksize=64, quant_type=NF4)

# Dequantize
weight_deq = dequantize_4bit(packed, absmax, blocksize=64, quant_type=NF4,
                              numel=weight.numel(), output_dtype=torch.float16)

# Fused GEMV (single vector)
x = torch.randn(4096, dtype=torch.float16, device="mps")
packed_w = packed.view(4096, -1)  # [N, K/2]
absmax_w = absmax.view(4096, -1)  # [N, K / blocksize]
y = gemv_4bit(x, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)

# Fused GEMM (batch of vectors)
X = torch.randn(8, 4096, dtype=torch.float16, device="mps")
Y = gemm_4bit(X, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)
```
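The `.view(4096, -1)` calls in the example rely on simple shape arithmetic. A row of `K` 4-bit values packs into `K / 2` bytes and carries `K / blocksize` absmax entries. This only lines up row by row when `K` is a multiple of `blocksize`, so that no block straddles two rows of the flattened weight. The helper below makes that arithmetic explicit. `packed_shapes` is hypothetical and not part of the package, and the divisibility requirement is an assumption inferred from the row-major flattening used above.

```python
from typing import Tuple


def packed_shapes(n: int, k: int, blocksize: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return the ([N, K/2], [N, K/blocksize]) shapes for a row-major [N, K] weight.

    Hypothetical helper: assumes the weight was flattened row-major before
    quantization, so rows map cleanly onto bytes and blocks.
    """
    if k % 2 != 0:
        raise ValueError("K must be even: two 4-bit values share one byte")
    if k % blocksize != 0:
        raise ValueError("K must be a multiple of blocksize so blocks do not straddle rows")
    return (n, k // 2), (n, k // blocksize)


# The 4096 x 4096 weight from the example, quantized with blocksize=64.
print(packed_shapes(4096, 4096, 64))  # ((4096, 2048), (4096, 64))
```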

## Supported Configurations

- **Scalar types**: float16, bfloat16, float32
- **Block sizes**: 64, 128
- **Quant types**: FP4, NF4

## Architecture

The kernels are adapted from [MLX quantization Metal kernels](https://github.com/ml-explore/mlx) with the following modifications:

1. **Codebook-based dequantization** replaces MLX's affine `scale * q + bias` with `codebook[q] * absmax`
2. **BnB packing format**: high nibble first (vs MLX's low nibble first)
3. **`BnBQuantizedBlockLoader`**: Custom block loader for tiled GEMM that dequantizes on-the-fly using codebook lookup
4. **Binary search quantization**: Efficient NF4/FP4 quantization using decision trees (matching the bitsandbytes CUDA kernels)
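Nearest-codebook quantization reduces to a binary search once the codebook is sorted. A value maps to code `i` when it falls between the midpoints on either side of `codebook[i]`. The sketch below uses Python's `bisect` over the 15 NF4 midpoints as a stand-in for the hard-coded comparison tree described above. It then applies per-block absmax scaling and high-nibble-first packing. It is a pure-Python illustration, not the Metal code. The NF4 table is the standard bitsandbytes/QLoRA one and is assumed to match. `quantize_4bit_ref` is a hypothetical name, and the tie-breaking toward the lower code and the zero-nibble padding are choices made by this sketch.

```python
import bisect
from typing import List, Sequence, Tuple

# Standard NF4 table from bitsandbytes/QLoRA, sorted ascending (assumed to match).
NF4_CODEBOOK = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

# Decision boundaries between neighbouring codes: 15 midpoints for 16 codes.
NF4_MIDPOINTS = [(a + b) / 2 for a, b in zip(NF4_CODEBOOK, NF4_CODEBOOK[1:])]


def quantize_4bit_ref(values: Sequence[float], blocksize: int) -> Tuple[bytes, List[float]]:
    """Blockwise NF4 quantization: per-block absmax, nearest-code search, packing."""
    indices: List[int] = []
    absmax: List[float] = []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        scale = max(abs(v) for v in block)
        absmax.append(scale)
        for v in block:
            normalized = v / scale if scale > 0 else 0.0
            # Count of midpoints strictly below the value = index of nearest code.
            # bisect_left sends exact ties to the lower code.
            indices.append(bisect.bisect_left(NF4_MIDPOINTS, normalized))
    if len(indices) % 2:
        indices.append(0)  # pad to a whole byte (sketch's choice; trimmed via numel)
    # Pack two codes per byte, first element in the high nibble.
    packed = bytes((hi << 4) | lo for hi, lo in zip(indices[0::2], indices[1::2]))
    return packed, absmax
```

Each lookup costs four comparisons instead of sixteen. This shortcut only works because NF4 is stored in ascending order. A sign-magnitude FP4 table would need to be sorted first and its indices mapped back.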

## Building

```bash
pip install kernel-builder
kernel-builder build .
```