medmekk commited on
Commit
20347e1
·
verified ·
1 Parent(s): a24c715

Upload folder using huggingface_hub

Browse files
Files changed (38) hide show
  1. .gitattributes +2 -0
  2. .pytest_cache/.gitignore +2 -0
  3. .pytest_cache/CACHEDIR.TAG +4 -0
  4. .pytest_cache/README.md +8 -0
  5. .pytest_cache/v/cache/lastfailed +1 -0
  6. .pytest_cache/v/cache/nodeids +21 -0
  7. README.md +70 -0
  8. bitsandbytes_mps/bf16.h +29 -0
  9. bitsandbytes_mps/bf16_math.h +380 -0
  10. bitsandbytes_mps/bnb_quantized.h +541 -0
  11. bitsandbytes_mps/bnb_quantized.metal +48 -0
  12. bitsandbytes_mps/bnb_quantized.mm +382 -0
  13. bitsandbytes_mps/bnb_types.h +180 -0
  14. bitsandbytes_mps/complex.h +173 -0
  15. bitsandbytes_mps/defines.h +24 -0
  16. bitsandbytes_mps/gemm/defines.h +5 -0
  17. bitsandbytes_mps/gemm/gemm.h +295 -0
  18. bitsandbytes_mps/gemm/loader.h +137 -0
  19. bitsandbytes_mps/gemm/mma.h +735 -0
  20. bitsandbytes_mps/gemm/params.h +64 -0
  21. bitsandbytes_mps/gemm/transforms.h +72 -0
  22. bitsandbytes_mps/gemm/utils.h +42 -0
  23. bitsandbytes_mps/gemm/utils/integral_constant.h +134 -0
  24. bitsandbytes_mps/gemm/utils/type_traits.h +55 -0
  25. bitsandbytes_mps/quantized_utils.h +90 -0
  26. bitsandbytes_mps/utils.h +393 -0
  27. build.toml +49 -0
  28. build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so +3 -0
  29. build/torch210-metal-aarch64-darwin/_ops.py +3 -3
  30. build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so +3 -0
  31. build/torch29-metal-aarch64-darwin/_ops.py +3 -3
  32. flake.lock +95 -0
  33. flake.nix +17 -0
  34. tests/__pycache__/test_bnb_mps.cpython-312-pytest-8.4.2.pyc +0 -0
  35. tests/test_bnb_mps.py +256 -0
  36. torch-ext/bitsandbytes_mps/__init__.py +165 -0
  37. torch-ext/torch_binding.cpp +35 -0
  38. torch-ext/torch_binding.h +53 -0
.gitattributes CHANGED
@@ -39,3 +39,5 @@ build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filt
39
  build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
40
  torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
41
  torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
 
39
  build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
40
  torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
41
  torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
42
+ build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
43
+ build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Created by pytest automatically.
2
+ *
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Signature: 8a477f597d28d172789f06886806bc55
2
+ # This file is a cache directory tag created by pytest.
3
+ # For information about cache directory tags, see:
4
+ # https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # pytest cache directory #
2
+
3
+ This directory contains data from the pytest's cache plugin,
4
+ which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
5
+
6
+ **Do not** commit this to version control.
7
+
8
+ See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/lastfailed ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "tests/test_bnb_mps.py::test_dequantize_matches_reference[128]",
3
+ "tests/test_bnb_mps.py::test_dequantize_matches_reference[64]",
4
+ "tests/test_bnb_mps.py::test_gemm_correctness[1-128]",
5
+ "tests/test_bnb_mps.py::test_gemm_correctness[1-64]",
6
+ "tests/test_bnb_mps.py::test_gemm_correctness[2-128]",
7
+ "tests/test_bnb_mps.py::test_gemm_correctness[2-64]",
8
+ "tests/test_bnb_mps.py::test_gemv_correctness[1-128]",
9
+ "tests/test_bnb_mps.py::test_gemv_correctness[1-64]",
10
+ "tests/test_bnb_mps.py::test_gemv_correctness[2-128]",
11
+ "tests/test_bnb_mps.py::test_gemv_correctness[2-64]",
12
+ "tests/test_bnb_mps.py::test_linear_4bit_auto_select",
13
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-1-128]",
14
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-1-64]",
15
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-2-128]",
16
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-2-64]",
17
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-1-128]",
18
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-1-64]",
19
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-2-128]",
20
+ "tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-2-64]"
21
+ ]
README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # bitsandbytes-mps
2
+
3
+ Metal (MPS) kernels for bitsandbytes 4-bit quantization on Apple Silicon.
4
+
5
+ Provides NF4 and FP4 blockwise quantization, dequantization, and **fused GEMV/GEMM** operations for efficient inference with 4-bit quantized models on macOS.
6
+
7
+ ## Operations
8
+
9
+ | Operation | Description |
10
+ |-----------|-------------|
11
+ | `quantize_4bit` | Blockwise 4-bit quantization (NF4/FP4) with per-block absmax |
12
+ | `dequantize_4bit` | Blockwise 4-bit dequantization using codebook lookup |
13
+ | `gemv_4bit` | Fused dequantize + matrix-vector multiply (batch_size=1 inference) |
14
+ | `gemm_4bit` | Fused dequantize + matrix-matrix multiply (larger batch inference) |
15
+ | `linear_4bit` | Auto-selecting linear layer (GEMV for vectors, GEMM for matrices) |
16
+
17
+ ## Quantization Format
18
+
19
+ Uses the bitsandbytes blockwise quantization scheme:
20
+ - **Packing**: 2 values per byte (high nibble = first element, low nibble = second)
21
+ - **Scaling**: One `absmax` (float32) per block of `blocksize` elements
22
+ - **Codebook**: NF4 (16 values optimized for normal distributions) or FP4 (sign-magnitude floating point)
23
+ - **Dequantization**: `value = codebook[4bit_index] * absmax`
24
+
25
+ ## Usage
26
+
27
+ ```python
28
+ import torch
29
+ from bitsandbytes_mps import quantize_4bit, dequantize_4bit, gemv_4bit, gemm_4bit, NF4
30
+
31
+ # Quantize a weight matrix
32
+ weight = torch.randn(4096, 4096, dtype=torch.float16, device="mps")
33
+ packed, absmax = quantize_4bit(weight.flatten(), blocksize=64, quant_type=NF4)
34
+
35
+ # Dequantize
36
+ weight_deq = dequantize_4bit(packed, absmax, blocksize=64, quant_type=NF4,
37
+ numel=weight.numel(), output_dtype=torch.float16)
38
+
39
+ # Fused GEMV (single vector)
40
+ x = torch.randn(4096, dtype=torch.float16, device="mps")
41
+ packed_w = packed.view(4096, -1) # [N, K/2]
42
+ absmax_w = absmax.view(4096, -1) # [N, K_groups]
43
+ y = gemv_4bit(x, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)
44
+
45
+ # Fused GEMM (batch of vectors)
46
+ X = torch.randn(8, 4096, dtype=torch.float16, device="mps")
47
+ Y = gemm_4bit(X, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)
48
+ ```
49
+
50
+ ## Supported Configurations
51
+
52
+ - **Scalar types**: float16, bfloat16, float32
53
+ - **Block sizes**: 64, 128
54
+ - **Quant types**: FP4, NF4
55
+
56
+ ## Architecture
57
+
58
+ The kernels are adapted from [MLX quantization Metal kernels](https://github.com/ml-explore/mlx) with the following modifications:
59
+
60
+ 1. **Codebook-based dequantization** replaces MLX's affine `scale * q + bias` with `codebook[q] * absmax`
61
+ 2. **BnB packing format**: high nibble first (vs MLX's low nibble first)
62
+ 3. **`BnBQuantizedBlockLoader`**: Custom block loader for tiled GEMM that dequantizes on-the-fly using codebook lookup
63
+ 4. **Binary search quantization**: Efficient NF4/FP4 quantization using decision trees (matching CUDA kernels)
64
+
65
+ ## Building
66
+
67
+ ```bash
68
+ pip install kernel-builder
69
+ kernel-builder build .
70
+ ```
bitsandbytes_mps/bf16.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_stdlib>
6
+
7
+ using namespace metal;
8
+
9
+ #if __METAL_VERSION__ >= 310
10
+ typedef bfloat bfloat16_t;
11
+ inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
12
+ return as_type<uint16_t>(x);
13
+ }
14
+
15
+ inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
16
+ return as_type<bfloat16_t>(x);
17
+ }
18
+ #else
19
+ // bfloat not available before Metal 3.1; use a stub so the file parses
20
+ // but only half/float kernels will be instantiated.
21
+ typedef half bfloat16_t;
22
+ inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
23
+ return as_type<uint16_t>(x);
24
+ }
25
+
26
+ inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
27
+ return as_type<bfloat16_t>(x);
28
+ }
29
+ #endif
bitsandbytes_mps/bf16_math.h ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ ///////////////////////////////////////////////////////////////////////////////
6
+ // Metal math for bfloat16
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+
9
+ /*
10
+
11
+ Following the Metal Shading Language Specification (Metal 3.1)
12
+
13
+ "bfloat is an extended itypeing point type that only allows implicit conversion
14
+ to a type of greater itypeing point rank. While bfloat can be implicitly
15
+ converted to itype, it cannot be implicitly converted to half, and neither
16
+ itype nor half can be implicitly converted to bfloat."
17
+
18
+ Further, as far as I can tell, the stdlib math/simd functions are not defined
19
+ for bfloat and calling with an argument of type bfloat will result in that
20
+ argument getting implicitly converted to itype which then returns an output
21
+ that is (likely) a itype which cannot be implicitly converted into a bfloat
22
+
23
+ This leads to situations where
24
+ bfloat a = 5.0bf;
25
+ bfloat b = metal::abs(a); // this will throw an error since abs return itype
26
+ bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
27
+
28
+ For the moment, I will be adding overloaded instantiations of the math
29
+ functions to accordingly automatically handle the casting
30
+
31
+ */
32
+
33
+ #define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
34
+ \
35
+ METAL_FUNC otype abs(itype x) { \
36
+ return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
37
+ } \
38
+ METAL_FUNC otype acos(itype x) { \
39
+ return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
40
+ } \
41
+ METAL_FUNC otype acosh(itype x) { \
42
+ return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
43
+ } \
44
+ METAL_FUNC otype asin(itype x) { \
45
+ return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
46
+ } \
47
+ METAL_FUNC otype asinh(itype x) { \
48
+ return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
49
+ } \
50
+ METAL_FUNC otype atan(itype y_over_x) { \
51
+ return static_cast<otype>( \
52
+ __metal_atan(static_cast<ctype>(y_over_x), mfast)); \
53
+ } \
54
+ METAL_FUNC otype atan2(itype y, itype x) { \
55
+ return static_cast<otype>( \
56
+ __metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
57
+ } \
58
+ METAL_FUNC otype atanh(itype x) { \
59
+ return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
60
+ } \
61
+ METAL_FUNC otype ceil(itype x) { \
62
+ return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
63
+ } \
64
+ METAL_FUNC otype cos(itype x) { \
65
+ return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
66
+ } \
67
+ METAL_FUNC otype cosh(itype x) { \
68
+ return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
69
+ } \
70
+ METAL_FUNC otype cospi(itype x) { \
71
+ return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
72
+ } \
73
+ METAL_FUNC otype divide(itype x, itype y) { \
74
+ return static_cast<otype>( \
75
+ __metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
76
+ } \
77
+ METAL_FUNC otype exp(itype x) { \
78
+ return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
79
+ } \
80
+ METAL_FUNC otype exp10(itype x) { \
81
+ return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
82
+ } \
83
+ METAL_FUNC otype exp2(itype x) { \
84
+ return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
85
+ } \
86
+ METAL_FUNC otype fabs(itype x) { \
87
+ return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
88
+ } \
89
+ METAL_FUNC otype fdim(itype x, itype y) { \
90
+ ctype t = static_cast<ctype>(x - y); \
91
+ return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
92
+ } \
93
+ METAL_FUNC otype floor(itype x) { \
94
+ return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
95
+ } \
96
+ METAL_FUNC otype fma(itype x, itype y, itype z) { \
97
+ return static_cast<otype>(__metal_fma( \
98
+ static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
99
+ } \
100
+ METAL_FUNC otype fmax(itype x, itype y) { \
101
+ return static_cast<otype>( \
102
+ __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
103
+ } \
104
+ METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
105
+ return static_cast<otype>(__metal_fmax3( \
106
+ static_cast<ctype>(x), \
107
+ static_cast<ctype>(y), \
108
+ static_cast<ctype>(z), \
109
+ mfast)); \
110
+ } \
111
+ METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
112
+ return static_cast<otype>(__metal_fmedian3( \
113
+ static_cast<ctype>(x), \
114
+ static_cast<ctype>(y), \
115
+ static_cast<ctype>(z), \
116
+ mfast)); \
117
+ } \
118
+ METAL_FUNC otype fmin(itype x, itype y) { \
119
+ return static_cast<otype>( \
120
+ __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
121
+ } \
122
+ METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
123
+ return static_cast<otype>(__metal_fmin3( \
124
+ static_cast<ctype>(x), \
125
+ static_cast<ctype>(y), \
126
+ static_cast<ctype>(z), \
127
+ mfast)); \
128
+ } \
129
+ METAL_FUNC otype fmod(itype x, itype y) { \
130
+ return static_cast<otype>( \
131
+ __metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
132
+ } \
133
+ METAL_FUNC otype fract(itype x) { \
134
+ return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
135
+ } \
136
+ METAL_FUNC otype frexp(itype x, thread int& exp) { \
137
+ return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
138
+ } \
139
+ METAL_FUNC otype ldexp(itype x, int k) { \
140
+ return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
141
+ } \
142
+ METAL_FUNC otype log(itype x) { \
143
+ return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
144
+ } \
145
+ METAL_FUNC otype log10(itype x) { \
146
+ return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
147
+ } \
148
+ METAL_FUNC otype log2(itype x) { \
149
+ return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
150
+ } \
151
+ METAL_FUNC otype max(itype x, itype y) { \
152
+ return static_cast<otype>( \
153
+ __metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
154
+ } \
155
+ METAL_FUNC otype max3(itype x, itype y, itype z) { \
156
+ return static_cast<otype>(__metal_fmax3( \
157
+ static_cast<ctype>(x), \
158
+ static_cast<ctype>(y), \
159
+ static_cast<ctype>(z), \
160
+ mfast)); \
161
+ } \
162
+ METAL_FUNC otype median3(itype x, itype y, itype z) { \
163
+ return static_cast<otype>(__metal_fmedian3( \
164
+ static_cast<ctype>(x), \
165
+ static_cast<ctype>(y), \
166
+ static_cast<ctype>(z), \
167
+ mfast)); \
168
+ } \
169
+ METAL_FUNC otype min(itype x, itype y) { \
170
+ return static_cast<otype>( \
171
+ __metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
172
+ } \
173
+ METAL_FUNC otype min3(itype x, itype y, itype z) { \
174
+ return static_cast<otype>(__metal_fmin3( \
175
+ static_cast<ctype>(x), \
176
+ static_cast<ctype>(y), \
177
+ static_cast<ctype>(z), \
178
+ mfast)); \
179
+ } \
180
+ METAL_FUNC otype nextafter(itype x, itype y) { \
181
+ return static_cast<otype>( \
182
+ __metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
183
+ } \
184
+ METAL_FUNC otype pow(itype x, itype y) { \
185
+ return static_cast<otype>( \
186
+ __metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
187
+ } \
188
+ METAL_FUNC otype powr(itype x, itype y) { \
189
+ return static_cast<otype>( \
190
+ __metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
191
+ } \
192
+ METAL_FUNC otype rint(itype x) { \
193
+ return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
194
+ } \
195
+ METAL_FUNC otype round(itype x) { \
196
+ return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
197
+ } \
198
+ METAL_FUNC otype rsqrt(itype x) { \
199
+ return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
200
+ } \
201
+ METAL_FUNC otype sin(itype x) { \
202
+ return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
203
+ } \
204
+ METAL_FUNC otype sinh(itype x) { \
205
+ return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
206
+ } \
207
+ METAL_FUNC otype sinpi(itype x) { \
208
+ return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
209
+ } \
210
+ METAL_FUNC otype sqrt(itype x) { \
211
+ return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
212
+ } \
213
+ METAL_FUNC otype tan(itype x) { \
214
+ return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
215
+ } \
216
+ METAL_FUNC otype tanh(itype x) { \
217
+ return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
218
+ } \
219
+ METAL_FUNC otype tanpi(itype x) { \
220
+ return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
221
+ } \
222
+ METAL_FUNC otype trunc(itype x) { \
223
+ return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
224
+ }
225
+
226
+ namespace metal {
227
+
228
+ instantiate_metal_math_funcs(
229
+ bfloat16_t,
230
+ bfloat16_t,
231
+ float,
232
+ __METAL_MAYBE_FAST_MATH__);
233
+
234
+ namespace fast {
235
+
236
+ instantiate_metal_math_funcs(
237
+ bfloat16_t,
238
+ bfloat16_t,
239
+ float,
240
+ __METAL_FAST_MATH__);
241
+
242
+ } // namespace fast
243
+
244
+ namespace precise {
245
+
246
+ instantiate_metal_math_funcs(
247
+ bfloat16_t,
248
+ bfloat16_t,
249
+ float,
250
+ __METAL_PRECISE_MATH__);
251
+
252
+ } // namespace precise
253
+
254
+ } // namespace metal
255
+
256
+ ///////////////////////////////////////////////////////////////////////////////
257
+ // Metal simd for bfloat16
258
+ ///////////////////////////////////////////////////////////////////////////////
259
+
260
+ #define instantiate_metal_simd_comm_funcs( \
261
+ itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
262
+ \
263
+ METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
264
+ return ctype_to_otype( \
265
+ __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
266
+ } \
267
+ \
268
+ METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
269
+ return ctype_to_otype( \
270
+ __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
271
+ } \
272
+ \
273
+ METAL_FUNC otype simd_shuffle_and_fill_down( \
274
+ itype data, itype filling_data, ushort delta, ushort modulo) { \
275
+ return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
276
+ itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
277
+ } \
278
+ \
279
+ METAL_FUNC otype simd_shuffle_and_fill_down( \
280
+ itype data, itype filling_data, ushort delta) { \
281
+ return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
282
+ itype_to_ctype(data), \
283
+ itype_to_ctype(filling_data), \
284
+ delta, \
285
+ __metal_get_simdgroup_size(ushort()))); \
286
+ } \
287
+ \
288
+ METAL_FUNC otype simd_shuffle_and_fill_up( \
289
+ itype data, itype filling_data, ushort delta, ushort modulo) { \
290
+ return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
291
+ itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
292
+ } \
293
+ \
294
+ METAL_FUNC otype simd_shuffle_and_fill_up( \
295
+ itype data, itype filling_data, ushort delta) { \
296
+ return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
297
+ itype_to_ctype(data), \
298
+ itype_to_ctype(filling_data), \
299
+ delta, \
300
+ __metal_get_simdgroup_size(ushort()))); \
301
+ } \
302
+ \
303
+ METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
304
+ return ctype_to_otype( \
305
+ __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
306
+ } \
307
+ \
308
+ METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
309
+ return ctype_to_otype( \
310
+ __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
311
+ } \
312
+ \
313
+ METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
314
+ return ctype_to_otype( \
315
+ __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
316
+ } \
317
+ \
318
+ METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
319
+ return ctype_to_otype( \
320
+ __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
321
+ } \
322
+ \
323
+ METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
324
+ return ctype_to_otype( \
325
+ __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
326
+ }
327
+
328
+ #define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
329
+ \
330
+ METAL_FUNC otype simd_max(itype data) { \
331
+ return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
332
+ } \
333
+ \
334
+ METAL_FUNC otype simd_min(itype data) { \
335
+ return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
336
+ } \
337
+ \
338
+ METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
339
+ return static_cast<otype>( \
340
+ __metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
341
+ } \
342
+ \
343
+ METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
344
+ return static_cast<otype>( \
345
+ __metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
346
+ } \
347
+ \
348
+ METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
349
+ return static_cast<otype>( \
350
+ __metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
351
+ } \
352
+ \
353
+ METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
354
+ return static_cast<otype>( \
355
+ __metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
356
+ } \
357
+ \
358
+ METAL_FUNC otype simd_product(itype data) { \
359
+ return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
360
+ } \
361
+ \
362
+ METAL_FUNC otype simd_sum(itype data) { \
363
+ return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
364
+ } \
365
+ \
366
+ METAL_FUNC otype simd_xor(itype data) { \
367
+ return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
368
+ }
369
+
370
+ namespace metal {
371
+
372
+ instantiate_metal_simd_comm_funcs(
373
+ bfloat16_t,
374
+ bfloat16_t,
375
+ uint16_t,
376
+ bfloat16_to_uint16,
377
+ uint16_to_bfloat16);
378
+ instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
379
+
380
+ } // namespace metal
bitsandbytes_mps/bnb_quantized.h ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // bitsandbytes MPS Metal kernels - 4-bit quantized operations
2
+ // Adapted from MLX quantized.h for bitsandbytes NF4/FP4 format.
3
+ //
4
+ // Key differences from MLX affine quantization:
5
+ // MLX: dequant(q) = scale * q_int + bias (linear mapping)
6
+ // BnB: dequant(q) = codebook[q_int] * absmax (lookup-based)
7
+ //
8
+ // Packing format:
9
+ // BnB: high nibble = first element, low nibble = second element
10
+ // Two 4-bit values per byte, pack_factor = 2
11
+
12
+ #include <metal_simdgroup>
13
+ #include <metal_stdlib>
14
+
15
+ #include "bnb_types.h"
16
+
17
+ using namespace metal;
18
+
19
+ #define MLX_MTL_CONST static constant constexpr const
20
+
21
+ MLX_MTL_CONST int SIMD_SIZE = 32;
22
+
23
+ // ============================================================================
24
+ // BnBQuantizedBlockLoader
25
+ //
26
+ // Loads blocks of BnB 4-bit packed weights into threadgroup memory,
27
+ // performing codebook dequantization on the fly.
28
+ // Adapted from MLX QuantizedBlockLoader.
29
+ //
30
+ // Template parameters:
31
+ // T - output scalar type (float16_t, bfloat16_t, float)
32
+ // BROWS - number of rows in the tile
33
+ // BCOLS - number of columns in the tile (unpacked)
34
+ // dst_ld - leading dimension of destination (threadgroup memory)
35
+ // reduction_dim - 0 for K along rows, 1 for K along columns
36
+ // tgp_size - threads per threadgroup
37
+ // blocksize - BnB blocksize (elements per absmax value)
38
+ // quant_type - BNB_FP4 (1) or BNB_NF4 (2)
39
+ // ============================================================================
40
+
41
+ template <
42
+ typename T,
43
+ short BROWS,
44
+ short BCOLS,
45
+ short dst_ld,
46
+ short reduction_dim,
47
+ short tgp_size,
48
+ short blocksize,
49
+ int quant_type>
50
+ struct BnBQuantizedBlockLoader {
51
+ static_assert(
52
+ BCOLS <= blocksize,
53
+ "The blocksize should be larger than the tile columns");
54
+ static_assert(
55
+ blocksize % BCOLS == 0,
56
+ "The blocksize should be divisible by the tile columns");
57
+
58
+ MLX_MTL_CONST short pack_factor = 2;
59
+ MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
60
+ MLX_MTL_CONST short n_reads =
61
+ (BCOLS_PACKED * BROWS < tgp_size) ? 1
62
+ : (BCOLS_PACKED * BROWS) / tgp_size;
63
+ MLX_MTL_CONST short group_steps = blocksize / BCOLS;
64
+
65
+ const int src_ld;
66
+ const int tile_stride;
67
+ short group_step_cnt;
68
+ const int group_stride;
69
+
70
+ const short thread_idx;
71
+ const short bi;
72
+ const short bj;
73
+
74
+ threadgroup T* dst;
75
+ const device uint8_t* src;
76
+ const device float* absmax_ptr;
77
+
78
+ BnBQuantizedBlockLoader(
79
+ const device uint8_t* src_,
80
+ const device float* absmax_,
81
+ const int src_ld_,
82
+ threadgroup T* dst_,
83
+ ushort simd_group_id [[simdgroup_index_in_threadgroup]],
84
+ ushort simd_lane_id [[thread_index_in_simdgroup]])
85
+ : src_ld(src_ld_),
86
+ tile_stride(
87
+ reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
88
+ group_step_cnt(0),
89
+ group_stride(BROWS * src_ld / blocksize),
90
+ thread_idx(simd_group_id * 32 + simd_lane_id),
91
+ bi(n_reads * thread_idx / BCOLS_PACKED),
92
+ bj((n_reads * thread_idx) % BCOLS_PACKED),
93
+ dst(dst_ + bi * dst_ld + bj * pack_factor),
94
+ src(src_ + bi * src_ld / pack_factor + bj),
95
+ absmax_ptr(absmax_ + bi * src_ld / blocksize) {}
96
+
97
+ void load_unsafe() const {
98
+ if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
99
+ return;
100
+ }
101
+
102
+ float am = *absmax_ptr;
103
+ for (int i = 0; i < n_reads; i++) {
104
+ bnb_dequantize<T, pack_factor, quant_type>(src + i, T(am), dst + i * pack_factor);
105
+ }
106
+ }
107
+
108
+ void load_safe(short2 src_tile_dim) const {
109
+ if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
110
+ return;
111
+ }
112
+
113
+ if (reduction_dim == 1 && bi >= src_tile_dim.x) {
114
+ for (int i = 0; i < n_reads * pack_factor; i++) {
115
+ dst[i] = T(0);
116
+ }
117
+ return;
118
+ }
119
+
120
+ if (reduction_dim == 0 && bi >= src_tile_dim.y) {
121
+ for (int i = 0; i < n_reads * pack_factor; i++) {
122
+ dst[i] = T(0);
123
+ }
124
+ return;
125
+ }
126
+
127
+ float am = *absmax_ptr;
128
+ for (int i = 0; i < n_reads; i++) {
129
+ bnb_dequantize<T, pack_factor, quant_type>(src + i, T(am), dst + i * pack_factor);
130
+ }
131
+ }
132
+
133
+ void next() {
134
+ src += tile_stride;
135
+ if (reduction_dim == 1) {
136
+ if (group_steps > 1) {
137
+ group_step_cnt++;
138
+ if (group_step_cnt == group_steps) {
139
+ group_step_cnt = 0;
140
+ absmax_ptr++;
141
+ }
142
+ } else {
143
+ absmax_ptr++;
144
+ }
145
+ } else {
146
+ absmax_ptr += group_stride;
147
+ }
148
+ }
149
+ };
150
+
151
+ // ============================================================================
152
+ // BnB GEMV (matrix-vector multiply with 4-bit quantized weights)
153
+ //
154
+ // Computes y = dequant(W) @ x
155
+ // W: [N, K/2] packed bytes, absmax: [N, ceil(K/blocksize)], x: [K], y: [N]
156
+ //
157
+ // Each simdgroup handles results_per_simdgroup output rows.
158
+ // Each thread processes values_per_thread elements of K per iteration.
159
+ // ============================================================================
160
+
161
// Simdgroup GEMV over 4-bit packed weights: y = dequant(W) @ x.
//
// W is [N, K/2] packed bytes (high nibble = first element), absmax is
// [N, K_groups] with K_groups = ceil(K / blocksize), x is [K] per batch row,
// y is [N] per batch row.  tid.x indexes the batch row, tid.y the tile of
// output rows.
template <typename T, int blocksize, int quant_type>
METAL_FUNC void bnb_qmv_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Work decomposition: 2 simdgroups per threadgroup, 4 output rows per
  // simdgroup, 4 packed bytes (8 K-elements) per lane per iteration.  The
  // host dispatch (rows_per_tg = 8, 64 threads) must match
  // num_simdgroups * results_per_simdgroup and num_simdgroups * SIMD_SIZE.
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int bytes_per_thread = 4;
  constexpr int values_per_thread = bytes_per_thread * 2;
  // K-elements consumed by the whole simdgroup per main-loop iteration.
  constexpr int block_size_k = values_per_thread * SIMD_SIZE;
  // Lanes sharing one absmax entry (assumes blocksize >= values_per_thread).
  constexpr int scale_step_per_thread = blocksize / values_per_thread;

  constant float* codebook = bnb_codebook<quant_type>();

  // Accumulate in float regardless of T for precision.
  typedef float U;
  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  const int K_packed = in_vec_size / 2;
  const int K_groups = (in_vec_size + blocksize - 1) / blocksize;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  if (out_row >= out_vec_size) {
    return;
  }

  // Clamp so the simdgroup always owns a full set of rows; edge tiles may
  // recompute rows already written by a neighboring tile (assumes
  // out_vec_size >= results_per_simdgroup).
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  const device uint8_t* ws =
      w + used_out_row * K_packed + simd_lid * bytes_per_thread;
  const device float* am =
      absmax + used_out_row * K_groups + simd_lid / scale_step_per_thread;
  const device T* xi = x + tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + used_out_row;

  // Main loop over full simdgroup-wide K chunks.
  int k = 0;
  for (; k < in_vec_size - block_size_k; k += block_size_k) {
    // Load x values
    for (int i = 0; i < values_per_thread; i++) {
      x_thread[i] = U(xi[i]);
    }

    // Compute dot product for each output row
    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = ws + row * K_packed;
      U scale = U(am[row * K_groups]);

      U accum = 0;
      for (int i = 0; i < bytes_per_thread; i++) {
        uint8_t byte_val = wl[i];
        // High nibble is the first (even-index) element, low nibble second.
        U w0 = U(codebook[(byte_val >> 4) & 0x0f]);
        U w1 = U(codebook[byte_val & 0x0f]);
        accum += x_thread[2 * i] * w0 + x_thread[2 * i + 1] * w1;
      }
      result[row] += accum * scale;
    }

    ws += block_size_k / 2;
    // NOTE(review): integer step is 0 when blocksize > block_size_k (256),
    // which would stop the absmax pointer from advancing -- confirm such
    // blocksizes are never dispatched to this kernel (the host wrappers
    // only accept 64/128).
    am += block_size_k / blocksize;
    xi += block_size_k;
  }

  // Handle remaining K elements (tail chunk; lanes past the end contribute 0)
  const int remaining = clamp(
      static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
      0,
      values_per_thread);
  if (remaining > 0) {
    for (int i = 0; i < remaining; i++) {
      x_thread[i] = U(xi[i]);
    }
    for (int i = remaining; i < values_per_thread; i++) {
      x_thread[i] = 0;
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = ws + row * K_packed;
      U scale = U(am[row * K_groups]);

      U accum = 0;
      int bytes_to_read = (remaining + 1) / 2;
      for (int i = 0; i < bytes_to_read; i++) {
        uint8_t byte_val = wl[i];
        U w0 = U(codebook[(byte_val >> 4) & 0x0f]);
        U w1 = U(codebook[byte_val & 0x0f]);
        // Odd tails are safe: x_thread past `remaining` was zeroed above.
        accum += x_thread[2 * i] * w0 + x_thread[2 * i + 1] * w1;
      }
      result[row] += accum * scale;
    }
  }

  // Reduce across SIMD lanes; lane 0 writes the final values.
  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
267
+
268
+ // ============================================================================
269
+ // BnB GEMM with transposed weight (y = x @ dequant(w).T)
270
+ //
271
+ // x: [M, K], w: [N, K/2] packed, absmax: [N, ceil(K/blocksize)], y: [M, N]
272
+ //
273
+ // Uses tiled matrix multiply with BnBQuantizedBlockLoader for on-the-fly
274
+ // dequantization of weights during the GEMM computation.
275
+ // ============================================================================
276
+
277
// Tiled GEMM with on-the-fly 4-bit dequantization: Y = X @ dequant(W).T.
//
// One threadgroup computes a BM x BN output tile.  The K loop stages BK-wide
// slices of X and dequantized W through threadgroup memory (Xs / Ws) and
// feeds them to the steel BlockMMA.  Dispatch must supply exactly
// WM * WN * SIMD_SIZE threads per threadgroup.
template <
    typename T,
    const int blocksize,
    const int quant_type,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void bnb_qmm_t_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  // 2x2 grid of simdgroups per threadgroup; 2 4-bit values per packed byte.
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 2;

  // Threadgroup rows are padded past BK (presumably to stagger accesses and
  // reduce shared-memory conflicts -- mirrors the MLX steel kernels; the
  // caller must size Xs/Ws with the same padding).
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = BnBQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      blocksize,
      quant_type>;

  const int K_packed = K / pack_factor;
  const int K_groups = (K + blocksize - 1) / blocksize;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  // Advance base pointers to this tile's origin.
  x += y_row * static_cast<int64_t>(K);
  w += y_col * K_packed;
  absmax += y_col * K_groups;
  y += y_row * static_cast<int64_t>(N) + y_col;

  // Valid rows/cols of this (possibly edge) tile.
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(
      (const device uint8_t*)w, absmax, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  // Four specializations of the K loop so interior tiles can skip bounds
  // checks entirely (load_unsafe) and edge tiles only clamp the dimension
  // that actually overruns.
  if (num_els < BM) {
    if (num_outs < BN) {
      // Edge in both M and N.
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      // Edge in M only.
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (num_outs < BN) {
      // Edge in N only.
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      // Fully interior tile: no bounds checks.
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
394
+
395
+ // ============================================================================
396
+ // Kernel entry points
397
+ // ============================================================================
398
+
399
+ // ---- Standalone blockwise quantize ----
400
+ // Each thread handles one block of elements.
401
+
402
// One thread quantizes one `blocksize` block: pass 1 finds the block's
// absolute maximum, pass 2 normalizes each value to [-1, 1], maps it to a
// 4-bit code, and packs two codes per byte (high nibble = first element).
// `blocksize` is even, so blocks never share an output byte.
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_quantize_blockwise(
    const device T* input [[buffer(0)]],
    device float* absmax [[buffer(1)]],
    device uint8_t* packed [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]]) {
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(gid) >= num_blocks) {
    return;
  }

  int block_start = gid * blocksize;
  int block_end = min(block_start + blocksize, n);

  // Find absmax for this block
  float max_val = 0.0f;
  for (int i = block_start; i < block_end; i++) {
    float current = metal::abs(float(input[i]));
    max_val = metal::max(max_val, current);
  }
  absmax[gid] = max_val;

  // All-zero blocks quantize via inv = 0 (norm is forced to 0 below too).
  float inv = (max_val > 0.0f) ? 1.0f / max_val : 0.0f;

  // Quantize and pack pairs of values
  int out_byte = block_start / 2;
  for (int i = block_start; i < block_end; i += 2) {
    float norm0 = (max_val > 0.0f) ? clamp(float(input[i]) * inv, -1.0f, 1.0f)
                                   : 0.0f;
    uchar q0 = bnb_quantize_value<quant_type>(norm0);

    // Odd-length tail (only possible for the last block): pad with code 0.
    uchar q1 = 0;
    if (i + 1 < block_end) {
      float norm1 = (max_val > 0.0f)
          ? clamp(float(input[i + 1]) * inv, -1.0f, 1.0f)
          : 0.0f;
      q1 = bnb_quantize_value<quant_type>(norm1);
    }

    packed[out_byte++] = (q0 << 4) | (q1 & 0x0f);
  }
}
445
+
446
+ // ---- Standalone blockwise dequantize ----
447
+ // Each threadgroup handles one block. Threads within share the absmax.
448
+
449
// One threadgroup dequantizes one `blocksize` block.  Thread 0 loads the
// block's absmax into threadgroup memory; after a barrier, all threads
// cooperatively unpack byte pairs (striding by the threadgroup size).
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_dequantize_blockwise(
    const device uint8_t* packed [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    device T* output [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]]) {
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(tgid) >= num_blocks) {
    return;
  }

  constant float* codebook = bnb_codebook<quant_type>();

  int block_start = tgid * blocksize;
  int block_end = min(block_start + blocksize, n);

  // NOTE(review): some MSL toolchains reject initializers on threadgroup
  // variables -- confirm `= 0.0f` compiles on all supported targets.
  threadgroup float shared_scale = 0.0f;
  if (tid == 0) {
    shared_scale = absmax[tgid];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float scale = shared_scale;

  // Ceil-divide: the last pair may hold only a single valid element.
  int pairs_in_block = (block_end - block_start + 1) / 2;

  for (int pair = static_cast<int>(tid); pair < pairs_in_block;
       pair += static_cast<int>(tg_size)) {
    int elem_idx = block_start + pair * 2;
    int byte_idx = elem_idx / 2;
    uint8_t byte_val = packed[byte_idx];

    uint8_t high = (byte_val >> 4) & 0x0f;
    uint8_t low = byte_val & 0x0f;

    // High nibble is the first element of the pair.
    output[elem_idx] = T(codebook[high] * scale);
    if (elem_idx + 1 < block_end) {
      output[elem_idx + 1] = T(codebook[low] * scale);
    }
  }
}
492
+
493
+ // ---- GEMV kernel entry point ----
494
+ // y = dequant(W) @ x
495
+ // W: [N, K/2], absmax: [N, K_groups], x: [K], y: [N]
496
+
497
+ template <typename T, int blocksize, int quant_type>
498
+ [[kernel]] void bnb_qmv(
499
+ const device uint8_t* w [[buffer(0)]],
500
+ const device float* absmax [[buffer(1)]],
501
+ const device T* x [[buffer(2)]],
502
+ device T* y [[buffer(3)]],
503
+ const constant int& in_vec_size [[buffer(4)]],
504
+ const constant int& out_vec_size [[buffer(5)]],
505
+ uint3 tid [[threadgroup_position_in_grid]],
506
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
507
+ uint simd_lid [[thread_index_in_simdgroup]]) {
508
+ bnb_qmv_impl<T, blocksize, quant_type>(
509
+ w, absmax, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
510
+ }
511
+
512
+ // ---- GEMM (transposed weight) kernel entry point ----
513
+ // Y = X @ dequant(W).T
514
+ // X: [M, K], W: [N, K/2], absmax: [N, K_groups], Y: [M, N]
515
+
516
+ template <typename T, int blocksize, int quant_type>
517
+ [[kernel]] void bnb_qmm_t(
518
+ const device uint8_t* w [[buffer(0)]],
519
+ const device float* absmax [[buffer(1)]],
520
+ const device T* x [[buffer(2)]],
521
+ device T* y [[buffer(3)]],
522
+ const constant int& K [[buffer(4)]],
523
+ const constant int& N [[buffer(5)]],
524
+ const constant int& M [[buffer(6)]],
525
+ uint3 tid [[threadgroup_position_in_grid]],
526
+ uint lid [[thread_index_in_threadgroup]],
527
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
528
+ uint simd_lid [[thread_index_in_simdgroup]]) {
529
+ (void)lid;
530
+
531
+ constexpr int BM = 32;
532
+ constexpr int BK = 32;
533
+ constexpr int BN = 32;
534
+ constexpr int BK_padded = (BK + 16 / sizeof(T));
535
+
536
+ threadgroup T Xs[BM * BK_padded];
537
+ threadgroup T Ws[BN * BK_padded];
538
+
539
+ bnb_qmm_t_impl<T, blocksize, quant_type, BM, BK, BN>(
540
+ w, absmax, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
541
+ }
bitsandbytes_mps/bnb_quantized.metal ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // bitsandbytes MPS Metal kernels - template instantiations
2
+ // Instantiates kernel variants for all (type, blocksize, quant_type) combos.
3
+
4
+ // clang-format off
5
+ #include "utils.h"
6
+ #include "gemm/gemm.h"
7
+ #include "quantized_utils.h"
8
+ #include "bnb_quantized.h"
9
+
10
+ // ============================================================================
11
+ // Instantiation macros
12
+ // ============================================================================
13
+
14
// Generates a host-visible specialization whose exported name encodes the
// template arguments, e.g. "bnb_qmv_half_bs_64_qt_2".  This naming scheme
// must stay in sync with the kernel-name strings built by the host-side
// dispatch code (bnb_quantized.mm).
#define instantiate_bnb_kernel(name, type, blocksize, quant_type) \
  template [[host_name( \
      #name "_" #type "_bs_" #blocksize "_qt_" #quant_type \
  )]] [[kernel]] decltype(name<type, blocksize, quant_type>) \
      name<type, blocksize, quant_type>;

// ---- Instantiate all kernel types for a given (type, blocksize, quant_type) ----

#define instantiate_bnb_all_kernels(type, blocksize, quant_type) \
  instantiate_bnb_kernel(bnb_quantize_blockwise, type, blocksize, quant_type) \
  instantiate_bnb_kernel(bnb_dequantize_blockwise, type, blocksize, quant_type) \
  instantiate_bnb_kernel(bnb_qmv, type, blocksize, quant_type) \
  instantiate_bnb_kernel(bnb_qmm_t, type, blocksize, quant_type)

// ---- Instantiate for all quant types (FP4=1, NF4=2) ----

#define instantiate_bnb_quant_types(type, blocksize) \
  instantiate_bnb_all_kernels(type, blocksize, 1) \
  instantiate_bnb_all_kernels(type, blocksize, 2)

// ---- Instantiate for all blocksizes ----
// NOTE(review): the host wrappers currently TORCH_CHECK blocksize in
// {64, 128}; the 256/512 variants are compiled but unreachable via that path.

#define instantiate_bnb_blocksizes(type) \
  instantiate_bnb_quant_types(type, 64) \
  instantiate_bnb_quant_types(type, 128) \
  instantiate_bnb_quant_types(type, 256) \
  instantiate_bnb_quant_types(type, 512)

// ---- Instantiate for all scalar types ----

instantiate_bnb_blocksizes(half)
instantiate_bnb_blocksizes(bfloat16_t)
instantiate_bnb_blocksizes(float)
47
+
48
+ // clang-format on
bitsandbytes_mps/bnb_quantized.mm ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // bitsandbytes MPS Metal kernels - ObjC++ dispatch
2
+ // Interfaces between PyTorch MPS tensors and Metal compute kernels.
3
+ // Uses the same dispatch pattern as kernels-community/activation, with
4
+ // get_command_buffer() moved inside dispatch_sync to avoid race conditions
5
+ // during model loading.
6
+
7
+ #include <torch/torch.h>
8
+
9
+ #import <Foundation/Foundation.h>
10
+ #import <Metal/Metal.h>
11
+
12
+ #include <algorithm>
13
+ #include <iostream>
14
+ #include <sstream>
15
+ #include <unordered_map>
16
+
17
+ #ifdef EMBEDDED_METALLIB_HEADER
18
+ #include EMBEDDED_METALLIB_HEADER
19
+ #endif
20
+
21
+ // ============================================================================
22
+ // Metal helpers
23
+ // ============================================================================
24
+
25
// Extract the underlying MTLBuffer from an MPS tensor's storage.
// PyTorch MPS keeps the buffer handle as the storage data pointer, so a
// bit_cast (no retain) yields a non-owning reference; the tensor must
// outlive any encoder that binds this buffer.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& t) {
  return __builtin_bit_cast(id<MTLBuffer>, t.storage().data());
}
28
+
29
+ namespace {
30
+
31
+ static id<MTLLibrary> library = nil;
32
+
33
// Lazily create and cache the Metal library that holds the BnB kernels.
// Prefers the metallib embedded at build time (EMBEDDED_METALLIB_HEADER);
// otherwise falls back to the process default library.  Returns nil (after
// logging to stderr) on failure.
// NOTE(review): `library` is an unsynchronized global -- safe only if all
// callers are serialized (e.g. via the MPS dispatch queue); confirm.
id<MTLLibrary> get_library() {
  if (library != nil)
    return library;
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  NSError* error = nil;

#ifdef EMBEDDED_METALLIB_HEADER
  // Library bytes compiled into the binary; created on the system device.
  library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
  if (library == nil) {
    std::cerr << "Failed to create Metal library from embedded header"
              << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
  }
#else
  // newDefaultLibrary does not report through `error`; it remains nil here.
  library = [device newDefaultLibrary];
  if (library == nil) {
    std::cerr << "Failed to load Metal library" << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
  }
#endif
  return library;
}
59
+
60
+ id<MTLComputePipelineState> get_pipeline(const std::string& name) {
61
+ static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;
62
+ auto it = cache.find(name);
63
+ if (it != cache.end())
64
+ return it->second;
65
+
66
+ id<MTLLibrary> lib = get_library();
67
+ if (!lib)
68
+ return nil;
69
+
70
+ id<MTLFunction> func =
71
+ [lib newFunctionWithName:[NSString stringWithUTF8String:name.c_str()]];
72
+ if (!func) {
73
+ std::cerr << "Kernel not found: " << name << std::endl;
74
+ return nil;
75
+ }
76
+
77
+ NSError* error = nil;
78
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
79
+ id<MTLComputePipelineState> state =
80
+ [device newComputePipelineStateWithFunction:func error:&error];
81
+ if (!state) {
82
+ std::cerr << "Failed to create pipeline for " << name << std::endl;
83
+ return nil;
84
+ }
85
+ cache[name] = state;
86
+ return state;
87
+ }
88
+
89
+ std::string type_str(torch::ScalarType type) {
90
+ switch (type) {
91
+ case torch::kFloat32:
92
+ return "float";
93
+ case torch::kFloat16:
94
+ return "half";
95
+ case torch::kBFloat16:
96
+ return "bfloat16_t";
97
+ default:
98
+ throw std::runtime_error("Unsupported dtype for BnB MPS kernels");
99
+ }
100
+ }
101
+
102
// Bind a tensor's MTLBuffer to the encoder at `index`, honoring the
// tensor's storage offset (tensor views share one buffer at a byte offset).
// Strides are not encoded here; kernels index the buffer linearly.
void set_tensor(
    id<MTLComputeCommandEncoder> enc,
    const torch::Tensor& t,
    int index) {
  [enc setBuffer:getMTLBufferStorage(t)
          offset:t.storage_offset() * t.element_size()
         atIndex:index];
}
110
+
111
+ } // namespace
112
+
113
+ // ============================================================================
114
+ // Public API: quantize_4bit
115
+ // ============================================================================
116
+
117
// Blockwise 4-bit quantization of an MPS tensor.
// Returns (packed, absmax): `packed` is [ceil(n/2)] uint8 with two codes per
// byte, `absmax` is [ceil(n/blocksize)] float32 per-block scales.
// Dispatches one GPU thread per quantization block.
// NOTE(review): the kernel indexes the input linearly from its storage
// offset -- presumably callers pass contiguous tensors; confirm.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type) {
  TORCH_CHECK(input.is_mps(), "Input must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128,
      "Only blocksize 64 and 128 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");

  int n = static_cast<int>(input.numel());
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);
  int packed_size = (n + 1) / 2;

  auto absmax =
      torch::empty({num_blocks}, input.options().dtype(torch::kFloat32));
  auto packed =
      torch::empty({packed_size}, input.options().dtype(torch::kUInt8));

  // Kernel name must match the instantiation macros in bnb_quantized.metal.
  std::stringstream ss;
  ss << "bnb_quantize_blockwise_" << type_str(input.scalar_type()) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    // Serialize on the MPS queue; the command buffer is fetched inside the
    // block to avoid races with PyTorch's own encoding (see file header).
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        // Buffer binding order matches the kernel's [[buffer(i)]] indices.
        int idx = 0;
        set_tensor(encoder, input, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, packed, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One thread per block; the kernel bounds-checks gid itself.
        NSUInteger threads_per_tg = pipeline.threadExecutionWidth;
        MTLSize grid = MTLSizeMake(num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return std::make_tuple(packed, absmax);
}
178
+
179
+ // ============================================================================
180
+ // Public API: dequantize_blockwise
181
+ // ============================================================================
182
+
183
// Blockwise 4-bit dequantization: expands `packed` (two codes per byte) back
// to a flat [numel] tensor of `output_dtype`, scaling each block by its
// `absmax` entry.  Dispatches one threadgroup per quantization block.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    torch::ScalarType output_dtype) {
  TORCH_CHECK(packed.is_mps(), "packed must be on MPS device");
  TORCH_CHECK(absmax.is_mps(), "absmax must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128,
      "Only blocksize 64 and 128 are supported");

  int n = static_cast<int>(numel);
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);

  auto output = torch::empty({n}, packed.options().dtype(output_dtype));

  // Kernel name must match the instantiation macros in bnb_quantized.metal.
  std::stringstream ss;
  ss << "bnb_dequantize_blockwise_" << type_str(output_dtype) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, packed, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, output, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // Threadgroup size: ideally one thread per packed byte of a block,
        // clamped to the pipeline limit and rounded up to at least one
        // execution width so the threadgroup is not under-occupied.
        NSUInteger max_tg = pipeline.maxTotalThreadsPerThreadgroup;
        NSUInteger desired = (blocksize + 1) / 2;
        NSUInteger tg_size =
            std::min(max_tg, std::max(static_cast<NSUInteger>(1), desired));
        if (tg_size < pipeline.threadExecutionWidth) {
          tg_size = std::min(pipeline.threadExecutionWidth, max_tg);
        }

        // One threadgroup per quantization block.
        MTLSize grid = MTLSizeMake(tg_size * num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(tg_size, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return output;
}
248
+
249
+ // ============================================================================
250
+ // Public API: GEMV (matrix-vector multiply)
251
+ // y = dequant(W) @ x
252
+ // ============================================================================
253
+
254
// y = dequant(W) @ x.  W: [N, K/2] packed, absmax: [N, ceil(K/blocksize)].
// Dispatch: 64 threads per threadgroup (2 simdgroups of 32), with grid_y
// covering N in chunks of rows_per_tg output rows.
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");

  int K = static_cast<int>(x.size(-1));
  int N = static_cast<int>(output_features);

  // Output keeps x's leading dims, replacing the last with N.
  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  auto y = torch::zeros(out_sizes, x.options());

  // Kernel name must match the instantiation macros in bnb_quantized.metal.
  std::stringstream ss;
  ss << "bnb_qmv_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];

        // Must equal the kernel's num_simdgroups * results_per_simdgroup.
        int rows_per_tg = 8;
        int grid_y = (N + rows_per_tg - 1) / rows_per_tg;

        // NOTE(review): the grid's x-dim is 1 while the kernel uses tid.x as
        // the batch-row index, so only batch row 0 is computed -- confirm
        // callers route multi-row inputs to bnb_gemm_4bit instead.
        [encoder dispatchThreadgroups:MTLSizeMake(1, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(32 * 2, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
314
+
315
+ // ============================================================================
316
+ // Public API: GEMM (matrix-matrix multiply with transposed weight)
317
+ // Y = X @ dequant(W).T
318
+ // ============================================================================
319
+
320
// Y = X @ dequant(W).T.  X: [M, K], W: [N, K/2] packed, Y: [M, N].
// Dispatch: one threadgroup per 32x32 output tile, 128 threads each
// (must equal the kernel's WM * WN * SIMD_SIZE).
// NOTE(review): M is taken from size(-2) only; for inputs with more than
// two dims the extra leading batch dims are not iterated -- confirm callers
// flatten X to 2-D first.
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 2, "Input must be at least 2D for GEMM");

  int K = static_cast<int>(x.size(-1));
  int M = static_cast<int>(x.size(-2));
  int N = static_cast<int>(output_features);

  // Output keeps x's leading dims, replacing the last with N.
  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  auto y = torch::zeros(out_sizes, x.options());

  // Kernel name must match the instantiation macros in bnb_quantized.metal.
  std::stringstream ss;
  ss << "bnb_qmm_t_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&M length:sizeof(int) atIndex:idx++];

        // Tile grid: ceil(N/BN) x ceil(M/BM) with BM = BN = 32.
        int grid_x = (N + 31) / 32;
        int grid_y = (M + 31) / 32;

        [encoder dispatchThreadgroups:MTLSizeMake(grid_x, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
bitsandbytes_mps/bnb_types.h ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // bitsandbytes MPS Metal kernels - NF4/FP4 codebook definitions and helpers
2
+ // Adapted from bitsandbytes CUDA kernels (kernels.cu) for Apple Metal
3
+
4
+ #pragma once
5
+
6
+ #include <metal_stdlib>
7
+ using namespace metal;
8
+
9
+ // ============================================================================
10
+ // Quant type enum (matches bitsandbytes common.h)
11
+ // ============================================================================
12
+
13
// Integer codes shared with the host API (matches bitsandbytes common.h):
// used everywhere as the `quant_type` template parameter.
enum BnBQuantType {
  BNB_FP4 = 1,
  BNB_NF4 = 2,
};
17
+
18
+ // ============================================================================
19
+ // NF4 codebook - 16 values optimized for normal distributions
20
+ // Maps 4-bit indices (0-15) to float values in [-1, 1]
21
+ // ============================================================================
22
+
23
// Monotonically increasing: index 0 = -1, index 7 = 0, index 15 = +1.
// Values must match bitsandbytes' CUDA tables bit-for-bit so tensors
// quantized on one backend dequantize identically on the other.
constant float NF4_CODEBOOK[16] = {
    -1.0f,
    -0.6961928009986877f,
    -0.5250730514526367f,
    -0.39491748809814453f,
    -0.28444138169288635f,
    -0.18477343022823334f,
    -0.09105003625154495f,
    0.0f,
    0.07958029955625534f,
    0.16093020141124725f,
    0.24611230194568634f,
    0.33791524171829224f,
    0.44070982933044434f,
    0.5626170039176941f,
    0.7229568362236023f,
    1.0f,
};

// ============================================================================
// FP4 codebook - 16 values using sign-magnitude FP4 encoding
// Indices 0-7: non-negative, indices 8-15: negative (bit 3 = sign)
// ============================================================================

// Note the non-monotonic ordering within each half: the low 3 bits follow
// FP4 bit patterns rather than magnitude order -- do not binary-search this
// table directly (quantize_fp4 handles the mapping).
constant float FP4_CODEBOOK[16] = {
    0.0f,
    0.005208333333f,
    0.66666667f,
    1.0f,
    0.33333333f,
    0.5f,
    0.16666667f,
    0.25f,
    0.0f,
    -0.005208333333f,
    -0.66666667f,
    -1.0f,
    -0.33333333f,
    -0.5f,
    -0.16666667f,
    -0.25f,
};
65
+
66
+ // ============================================================================
67
+ // Codebook accessor by quant_type template parameter
68
+ // ============================================================================
69
+
70
// Select the dequantization codebook at compile time.  Any quant_type other
// than BNB_NF4 falls through to the FP4 table (BnBQuantType has only two
// codes, so this covers BNB_FP4).
template <int quant_type>
inline constant float* bnb_codebook() {
  if (quant_type == BNB_NF4) {
    return NF4_CODEBOOK;
  } else {
    return FP4_CODEBOOK;
  }
}
78
+
79
+ // ============================================================================
80
+ // NF4 quantization - binary search (matches CUDA dQuantizeNF4)
81
+ // Input: normalized value in [-1, 1]
82
+ // Output: 4-bit index (0-15)
83
+ // ============================================================================
84
+
85
+ inline uchar quantize_nf4(float x) {
86
+ if (x > 0.03979014977812767f) {
87
+ if (x > 0.3893125355243683f) {
88
+ if (x > 0.6427869200706482f) {
89
+ return (x > 0.8614784181118011f) ? 15 : 14;
90
+ }
91
+ return (x > 0.5016634166240692f) ? 13 : 12;
92
+ }
93
+ if (x > 0.2035212516784668f) {
94
+ return (x > 0.2920137718319893f) ? 11 : 10;
95
+ }
96
+ return (x > 0.1202552504837513f) ? 9 : 8;
97
+ }
98
+ if (x > -0.33967943489551544f) {
99
+ if (x > -0.13791173323988914f) {
100
+ return (x > -0.045525018125772476f) ? 7 : 6;
101
+ }
102
+ return (x > -0.23460740596055984f) ? 5 : 4;
103
+ }
104
+ if (x > -0.6106329262256622f) {
105
+ return (x > -0.4599952697753906f) ? 3 : 2;
106
+ }
107
+ return (x > -0.8480964004993439f) ? 1 : 0;
108
+ }
109
+
110
+ // ============================================================================
111
+ // FP4 quantization - binary search (matches CUDA dQuantizeFP4)
112
+ // Input: normalized value in [-1, 1]
113
+ // Output: 4-bit index (0-15), MSB = sign bit
114
+ // ============================================================================
115
+
116
+ inline uchar quantize_fp4(float x) {
117
+ uchar sign = (x < 0.0f) ? 8 : 0;
118
+ x = metal::abs(x);
119
+ uchar code;
120
+ if (x > 0.29166667f) {
121
+ if (x > 0.75f) {
122
+ code = (x > 0.8333333f) ? 3 : 2;
123
+ } else {
124
+ code = (x > 0.4166667f) ? 5 : 4;
125
+ }
126
+ } else {
127
+ if (x > 0.0859375f) {
128
+ code = (x > 0.20833333f) ? 7 : 6;
129
+ } else {
130
+ code = (x > 0.00260416f) ? 1 : 0;
131
+ }
132
+ }
133
+ return sign | code;
134
+ }
135
+
136
+ // ============================================================================
137
+ // Generic quantize dispatch by quant_type
138
+ // ============================================================================
139
+
140
+ template <int quant_type>
141
+ inline uchar bnb_quantize_value(float normalized) {
142
+ if (quant_type == BNB_NF4) {
143
+ return quantize_nf4(normalized);
144
+ } else {
145
+ return quantize_fp4(normalized);
146
+ }
147
+ }
148
+
149
+ // ============================================================================
150
+ // Dequantize a single 4-bit value using codebook lookup
151
+ // ============================================================================
152
+
153
// Dequantize one 4-bit code to its codebook value; absmax scaling is applied
// by the caller.  Masks to the low nibble so stray high bits are ignored.
template <int quant_type>
inline float bnb_dequantize_value(uchar nibble) {
  return bnb_codebook<quant_type>()[nibble & 0x0f];
}
157
+
158
+ // ============================================================================
159
+ // BnB 4-bit dequantize for block loader (adapted from MLX affine dequantize)
160
+ // Unpacks N values from packed bytes using codebook lookup.
161
+ //
162
+ // BnB packing: high nibble = first element, low nibble = second element
163
+ // Each byte stores 2 4-bit values.
164
+ // ============================================================================
165
+
166
// Unpack N 4-bit values (N/2 packed bytes) into threadgroup memory, scaling
// each codebook entry by absmax_val.  BnB packing: the high nibble of each
// byte is the first (even-index) element, the low nibble the second.
// N must be even (only full bytes are read).
template <typename U, int N, int quant_type>
inline void bnb_dequantize(
    const device uint8_t* w,
    U absmax_val,
    threadgroup U* w_local) {
  constant float* codebook = bnb_codebook<quant_type>();

  for (int i = 0; i < N / 2; i++) {
    uint8_t byte_val = w[i];
    uint8_t high = (byte_val >> 4) & 0x0f;
    uint8_t low = byte_val & 0x0f;
    w_local[2 * i] = U(codebook[high]) * absmax_val;
    w_local[2 * i + 1] = U(codebook[low]) * absmax_val;
  }
}
bitsandbytes_mps/complex.h ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_stdlib>
6
+
7
+ using namespace metal;
8
+
9
+ struct complex64_t;
10
+
11
+ template <typename T>
12
+ static constexpr constant bool can_convert_to_complex64 =
13
+ !is_same_v<T, complex64_t> && is_convertible_v<T, float>;
14
+
15
+ template <typename T>
16
+ static constexpr constant bool can_convert_from_complex64 =
17
+ !is_same_v<T, complex64_t> &&
18
+ (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
19
+
20
+ struct complex64_t {
21
+ float real;
22
+ float imag;
23
+
24
+ // Constructors
25
+ constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
26
+ constexpr complex64_t() : real(0), imag(0) {};
27
+ constexpr complex64_t() threadgroup : real(0), imag(0) {};
28
+
29
+ // Conversions to complex64_t
30
+ template <
31
+ typename T,
32
+ typename = typename enable_if<can_convert_to_complex64<T>>::type>
33
+ constexpr complex64_t(T x) thread : real(x), imag(0) {}
34
+
35
+ template <
36
+ typename T,
37
+ typename = typename enable_if<can_convert_to_complex64<T>>::type>
38
+ constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
39
+
40
+ template <
41
+ typename T,
42
+ typename = typename enable_if<can_convert_to_complex64<T>>::type>
43
+ constexpr complex64_t(T x) device : real(x), imag(0) {}
44
+
45
+ template <
46
+ typename T,
47
+ typename = typename enable_if<can_convert_to_complex64<T>>::type>
48
+ constexpr complex64_t(T x) constant : real(x), imag(0) {}
49
+
50
+ // Conversions from complex64_t
51
+ template <
52
+ typename T,
53
+ typename = typename enable_if<can_convert_from_complex64<T>>::type>
54
+ constexpr operator T() const thread {
55
+ return static_cast<T>(real);
56
+ }
57
+
58
+ template <
59
+ typename T,
60
+ typename = typename enable_if<can_convert_from_complex64<T>>::type>
61
+ constexpr operator T() const threadgroup {
62
+ return static_cast<T>(real);
63
+ }
64
+
65
+ template <
66
+ typename T,
67
+ typename = typename enable_if<can_convert_from_complex64<T>>::type>
68
+ constexpr operator T() const device {
69
+ return static_cast<T>(real);
70
+ }
71
+
72
+ template <
73
+ typename T,
74
+ typename = typename enable_if<can_convert_from_complex64<T>>::type>
75
+ constexpr operator T() const constant {
76
+ return static_cast<T>(real);
77
+ }
78
+ };
79
+
80
+ constexpr complex64_t operator-(complex64_t x) {
81
+ return {-x.real, -x.imag};
82
+ }
83
+
84
+ constexpr bool operator>=(complex64_t a, complex64_t b) {
85
+ return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
86
+ }
87
+
88
+ constexpr bool operator>(complex64_t a, complex64_t b) {
89
+ return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
90
+ }
91
+
92
+ constexpr bool operator<=(complex64_t a, complex64_t b) {
93
+ return operator>=(b, a);
94
+ }
95
+
96
+ constexpr bool operator<(complex64_t a, complex64_t b) {
97
+ return operator>(b, a);
98
+ }
99
+
100
+ constexpr bool operator==(complex64_t a, complex64_t b) {
101
+ return a.real == b.real && a.imag == b.imag;
102
+ }
103
+
104
+ constexpr complex64_t operator+(complex64_t a, complex64_t b) {
105
+ return {a.real + b.real, a.imag + b.imag};
106
+ }
107
+
108
+ constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
109
+ a.real += b.real;
110
+ a.imag += b.imag;
111
+ return a;
112
+ }
113
+
114
+ constexpr threadgroup complex64_t& operator+=(
115
+ threadgroup complex64_t& a,
116
+ complex64_t b) {
117
+ a.real += b.real;
118
+ a.imag += b.imag;
119
+ return a;
120
+ }
121
+
122
+ constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
123
+ a.real += b.real;
124
+ a.imag += b.imag;
125
+ return a;
126
+ }
127
+
128
+ constexpr complex64_t operator+(float a, complex64_t b) {
129
+ return {a + b.real, b.imag};
130
+ }
131
+ constexpr complex64_t operator+(complex64_t a, float b) {
132
+ return {a.real + b, a.imag};
133
+ }
134
+
135
+ constexpr complex64_t operator-(complex64_t a, complex64_t b) {
136
+ return {a.real - b.real, a.imag - b.imag};
137
+ }
138
+ constexpr complex64_t operator-(float a, complex64_t b) {
139
+ return {a - b.real, -b.imag};
140
+ }
141
+ constexpr complex64_t operator-(complex64_t a, float b) {
142
+ return {a.real - b, a.imag};
143
+ }
144
+
145
+ constexpr complex64_t operator*(complex64_t a, complex64_t b) {
146
+ return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
147
+ }
148
+
149
+ constexpr complex64_t operator/(complex64_t a, complex64_t b) {
150
+ auto denom = b.real * b.real + b.imag * b.imag;
151
+ auto x = a.real * b.real + a.imag * b.imag;
152
+ auto y = a.imag * b.real - a.real * b.imag;
153
+ return {x / denom, y / denom};
154
+ }
155
+
156
+ constexpr complex64_t operator/(float a, complex64_t b) {
157
+ auto denom = b.real * b.real + b.imag * b.imag;
158
+ auto x = a * b.real;
159
+ auto y = -a * b.imag;
160
+ return {x / denom, y / denom};
161
+ }
162
+
163
+ constexpr complex64_t operator%(complex64_t a, complex64_t b) {
164
+ auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
165
+ auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
166
+ if (real != 0 && (real < 0 != b.real < 0)) {
167
+ real += b.real;
168
+ }
169
+ if (imag != 0 && (imag < 0 != b.imag < 0)) {
170
+ imag += b.imag;
171
+ }
172
+ return {real, imag};
173
+ }
bitsandbytes_mps/defines.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #if defined __METAL__ || defined MLX_METAL_JIT
6
+ #define MTL_CONST constant
7
+ #else
8
+ #define MTL_CONST
9
+ #endif
10
+
11
+ static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
12
+ static MTL_CONST constexpr int REDUCE_N_READS = 4;
13
+ static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
14
+ static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
15
+ static MTL_CONST constexpr int RMS_N_READS = 4;
16
+ static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;
17
+
18
+ // Instantiate a templated kernel.
19
+ // Extra args are used as template parameters:
20
+ // e.g. instantiate_kernel(binary_int, binary, a, b) ->
21
+ // [[host_name(binary_int)]] [kernel] binary<a, b>
22
+ #define instantiate_kernel(name, func, ...) \
23
+ template [[host_name( \
24
+ name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
bitsandbytes_mps/gemm/defines.h ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #define STEEL_CONST static constant constexpr const
4
+ #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
5
+ #define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")
bitsandbytes_mps/gemm/gemm.h ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "gemm/loader.h"
6
+ #include "gemm/mma.h"
7
+ #include "gemm/params.h"
8
+ #include "gemm/transforms.h"
9
+ #include "gemm/utils.h"
10
+
11
+ using namespace metal;
12
+
13
+ ///////////////////////////////////////////////////////////////////////////////
14
+ // GEMM kernel class
15
+ ///////////////////////////////////////////////////////////////////////////////
16
+
17
+ namespace mlx {
18
+ namespace steel {
19
+
20
+ template <bool M_aligned, bool N_aligned, bool K_aligned>
21
+ struct LoopAlignment {};
22
+
23
+ template <
24
+ typename T,
25
+ typename U,
26
+ int BM,
27
+ int BN,
28
+ int BK,
29
+ int WM,
30
+ int WN,
31
+ bool transpose_a,
32
+ bool transpose_b,
33
+ bool MN_aligned,
34
+ bool K_aligned,
35
+ typename AccumType = typename AccumHelper<T>::accum_type,
36
+ typename Epilogue = TransformNone<U, AccumType>>
37
+ struct GEMMKernel {
38
+ STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
39
+ STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
40
+ STEEL_CONST short tgp_mem_size_a =
41
+ transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
42
+ STEEL_CONST short tgp_mem_size_b =
43
+ transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
44
+ STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
45
+
46
+ STEEL_CONST short tgp_size = WM * WN * 32;
47
+
48
+ using loader_a_t = BlockLoader<
49
+ T,
50
+ transpose_a ? BK : BM,
51
+ transpose_a ? BM : BK,
52
+ transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
53
+ !transpose_a,
54
+ tgp_size>;
55
+ using loader_b_t = BlockLoader<
56
+ T,
57
+ transpose_b ? BN : BK,
58
+ transpose_b ? BK : BN,
59
+ transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
60
+ transpose_b,
61
+ tgp_size>;
62
+ using mma_t = BlockMMA<
63
+ T,
64
+ U,
65
+ BM,
66
+ BN,
67
+ BK,
68
+ WM,
69
+ WN,
70
+ transpose_a,
71
+ transpose_b,
72
+ transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
73
+ transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
74
+ AccumType,
75
+ Epilogue>;
76
+
77
+ /* Main kernel function */
78
+ template <bool M_aligned, bool N_aligned, bool K_aligned_>
79
+ static METAL_FUNC void gemm_loop(
80
+ threadgroup T* As [[threadgroup(0)]],
81
+ threadgroup T* Bs [[threadgroup(1)]],
82
+ const int gemm_k_iterations,
83
+ thread loader_a_t& loader_a,
84
+ thread loader_b_t& loader_b,
85
+ thread mma_t& mma_op,
86
+ thread const short& tgp_bm,
87
+ thread const short& tgp_bn,
88
+ thread const short& lbk,
89
+ LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
90
+ // Appease the compiler
91
+ (void)l;
92
+
93
+ short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
94
+
95
+ short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
96
+
97
+ for (int k = 0; k < gemm_k_iterations; k++) {
98
+ threadgroup_barrier(mem_flags::mem_threadgroup);
99
+ // Load elements into threadgroup
100
+ if (M_aligned) {
101
+ loader_a.load_unsafe();
102
+ } else {
103
+ loader_a.load_safe(tile_dims_A);
104
+ }
105
+
106
+ if (N_aligned) {
107
+ loader_b.load_unsafe();
108
+ } else {
109
+ loader_b.load_safe(tile_dims_B);
110
+ }
111
+
112
+ threadgroup_barrier(mem_flags::mem_threadgroup);
113
+
114
+ // Multiply and accumulate threadgroup elements
115
+ mma_op.mma(As, Bs);
116
+
117
+ // Prepare for next iteration
118
+ loader_a.next();
119
+ loader_b.next();
120
+ }
121
+
122
+ if (!K_aligned_) {
123
+ threadgroup_barrier(mem_flags::mem_threadgroup);
124
+
125
+ short2 tile_dims_A_last =
126
+ transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
127
+ short2 tile_dims_B_last =
128
+ transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
129
+
130
+ loader_a.load_safe(tile_dims_A_last);
131
+ loader_b.load_safe(tile_dims_B_last);
132
+
133
+ threadgroup_barrier(mem_flags::mem_threadgroup);
134
+
135
+ mma_op.mma(As, Bs);
136
+ }
137
+ }
138
+
139
+ /* Main kernel function */
140
+ static METAL_FUNC void run(
141
+ const device T* A [[buffer(0)]],
142
+ const device T* B [[buffer(1)]],
143
+ device U* D [[buffer(2)]],
144
+ const constant GEMMParams* params [[buffer(3)]],
145
+ threadgroup T* As [[threadgroup(0)]],
146
+ threadgroup T* Bs [[threadgroup(1)]],
147
+ uint simd_lane_id [[thread_index_in_simdgroup]],
148
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
149
+ uint3 tid [[threadgroup_position_in_grid]],
150
+ uint3 lid [[thread_position_in_threadgroup]]) {
151
+ // Pacifying compiler
152
+ (void)lid;
153
+
154
+ const int tid_y = ((tid.y) << params->swizzle_log) +
155
+ ((tid.x) & ((1 << params->swizzle_log) - 1));
156
+ const int tid_x = (tid.x) >> params->swizzle_log;
157
+
158
+ if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
159
+ return;
160
+ }
161
+
162
+ threadgroup_barrier(mem_flags::mem_none);
163
+
164
+ // Find block in A, B, C
165
+ const int c_row = tid_y * BM;
166
+ const int c_col = tid_x * BN;
167
+ const size_t c_row_long = size_t(c_row);
168
+ const size_t c_col_long = size_t(c_col);
169
+
170
+ A += transpose_a ? c_row_long : c_row_long * params->lda;
171
+ B += transpose_b ? c_col_long * params->ldb : c_col_long;
172
+ D += c_row_long * params->ldd + c_col_long;
173
+
174
+ // Prepare threadgroup loading operations
175
+ thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
176
+ thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
177
+
178
+ // Prepare threadgroup mma operation
179
+ thread mma_t mma_op(simd_group_id, simd_lane_id);
180
+
181
+ int gemm_k_iterations = params->gemm_k_iterations_aligned;
182
+
183
+ ///////////////////////////////////////////////////////////////////////////////
184
+ // MNK aligned loop
185
+ if (MN_aligned) {
186
+ for (int k = 0; k < gemm_k_iterations; k++) {
187
+ threadgroup_barrier(mem_flags::mem_threadgroup);
188
+ // Load elements into threadgroup
189
+ loader_a.load_unsafe();
190
+ loader_b.load_unsafe();
191
+
192
+ threadgroup_barrier(mem_flags::mem_threadgroup);
193
+
194
+ // Multiply and accumulate threadgroup elements
195
+ mma_op.mma(As, Bs);
196
+
197
+ // Prepare for next iteration
198
+ loader_a.next();
199
+ loader_b.next();
200
+ }
201
+
202
+ threadgroup_barrier(mem_flags::mem_none);
203
+
204
+ // Loop tail
205
+ if (!K_aligned) {
206
+ int lbk = params->K - params->gemm_k_iterations_aligned * BK;
207
+ short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
208
+ short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
209
+
210
+ loader_a.load_safe(tile_dims_A);
211
+ loader_b.load_safe(tile_dims_B);
212
+
213
+ threadgroup_barrier(mem_flags::mem_threadgroup);
214
+
215
+ mma_op.mma(As, Bs);
216
+ }
217
+
218
+ // Store results to device memory
219
+ mma_op.store_result(D, params->ldd);
220
+ return;
221
+
222
+ }
223
+ ///////////////////////////////////////////////////////////////////////////////
224
+ // MN unaligned loop
225
+ else { // Loop over K - unaligned case
226
+ short tgp_bm = min(BM, params->M - c_row);
227
+ short tgp_bn = min(BN, params->N - c_col);
228
+ short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
229
+
230
+ if (tgp_bm == BM && tgp_bn == BN) {
231
+ gemm_loop<true, true, K_aligned>(
232
+ As,
233
+ Bs,
234
+ gemm_k_iterations,
235
+ loader_a,
236
+ loader_b,
237
+ mma_op,
238
+ tgp_bm,
239
+ tgp_bn,
240
+ leftover_bk);
241
+
242
+ mma_op.store_result(D, params->ldd);
243
+ return;
244
+
245
+ } else if (tgp_bn == BN) {
246
+ gemm_loop<false, true, K_aligned>(
247
+ As,
248
+ Bs,
249
+ gemm_k_iterations,
250
+ loader_a,
251
+ loader_b,
252
+ mma_op,
253
+ tgp_bm,
254
+ tgp_bn,
255
+ leftover_bk);
256
+
257
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
258
+ return;
259
+
260
+ } else if (tgp_bm == BM) {
261
+ gemm_loop<true, false, K_aligned>(
262
+ As,
263
+ Bs,
264
+ gemm_k_iterations,
265
+ loader_a,
266
+ loader_b,
267
+ mma_op,
268
+ tgp_bm,
269
+ tgp_bn,
270
+ leftover_bk);
271
+
272
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
273
+ return;
274
+
275
+ } else {
276
+ gemm_loop<false, false, K_aligned>(
277
+ As,
278
+ Bs,
279
+ gemm_k_iterations,
280
+ loader_a,
281
+ loader_b,
282
+ mma_op,
283
+ tgp_bm,
284
+ tgp_bn,
285
+ leftover_bk);
286
+
287
+ mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
288
+ return;
289
+ }
290
+ }
291
+ }
292
+ };
293
+
294
+ } // namespace steel
295
+ } // namespace mlx
bitsandbytes_mps/gemm/loader.h ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "gemm/defines.h"
6
+
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+ // Loading helper
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+
11
+ namespace mlx {
12
+ namespace steel {
13
+
14
+ template <
15
+ typename T,
16
+ short BROWS,
17
+ short BCOLS,
18
+ short dst_ld,
19
+ short reduction_dim,
20
+ short tgp_size,
21
+ short alignment = 1,
22
+ short n_reads = (BCOLS * BROWS) / (tgp_size),
23
+ short TCOLS = BCOLS / n_reads,
24
+ short TROWS = tgp_size / TCOLS>
25
+ struct BlockLoader {
26
+ STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
27
+ STEEL_CONST short vec_size = n_reads;
28
+
29
+ // Leading dimension for src
30
+ const int src_ld;
31
+ const int tile_stride;
32
+
33
+ // Thread location indices
34
+ const short thread_idx;
35
+ const short bi;
36
+ const short bj;
37
+
38
+ // threadgroup and device memory
39
+ threadgroup T* dst;
40
+ const device T* src;
41
+
42
+ struct alignas(alignment * sizeof(T)) ReadVector {
43
+ uint8_t v[sizeof(T) * vec_size];
44
+ };
45
+
46
+ /* Constructor */
47
+ METAL_FUNC BlockLoader(
48
+ const device T* src_,
49
+ const int src_ld_,
50
+ threadgroup T* dst_,
51
+ ushort simd_group_id [[simdgroup_index_in_threadgroup]],
52
+ ushort simd_lane_id [[thread_index_in_simdgroup]])
53
+ : src_ld(src_ld_),
54
+ tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
55
+ thread_idx(simd_group_id * 32 + simd_lane_id),
56
+ bi(thread_idx / TCOLS),
57
+ bj(vec_size * (thread_idx % TCOLS)),
58
+ dst(dst_ + bi * dst_ld + bj),
59
+ src(src_ + bi * src_ld + bj) {}
60
+
61
+ /* Apply operation to threadgroup without bound checking */
62
+ template <typename UnaryOp>
63
+ METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
64
+ STEEL_PRAGMA_UNROLL
65
+ for (short i = 0; i < BROWS; i += TROWS) {
66
+ STEEL_PRAGMA_UNROLL
67
+ for (short j = 0; j < vec_size; j++) {
68
+ dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
69
+ }
70
+ }
71
+ }
72
+
73
+ /* Load from device memory into threadgroup memory - without bound checking */
74
+ METAL_FUNC void load_unsafe() const {
75
+ STEEL_PRAGMA_UNROLL
76
+ for (short i = 0; i < BROWS; i += TROWS) {
77
+ *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
78
+ *((const device ReadVector*)(&src[i * src_ld]));
79
+ }
80
+ }
81
+
82
+ /* Load from device memory into threadgroup memory - with bound checking */
83
+ METAL_FUNC void load_safe(short2 src_tile_dim) const {
84
+ src_tile_dim = src_tile_dim - short2(bj, bi);
85
+
86
+ // Skip loading if thread has no valid reads
87
+ if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
88
+ STEEL_PRAGMA_UNROLL
89
+ for (short i = 0; i < BROWS; i += TROWS) {
90
+ STEEL_PRAGMA_UNROLL
91
+ for (short j = 0; j < vec_size; j++) {
92
+ dst[i * dst_ld + j] = T(0);
93
+ }
94
+ }
95
+ return;
96
+ }
97
+
98
+ // Use fast thread memory for bound checks
99
+ bool tmp_idx[vec_size];
100
+ T tmp_val[vec_size];
101
+
102
+ STEEL_PRAGMA_UNROLL
103
+ for (short i = 0; i < BROWS; i += TROWS) {
104
+ // Make sure tmp_idx only contains valid indices
105
+ STEEL_PRAGMA_UNROLL
106
+ for (short j = 0; j < vec_size; j++) {
107
+ tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
108
+ }
109
+
110
+ // Read valid indices into tmp_val
111
+ STEEL_PRAGMA_UNROLL
112
+ for (short j = 0; j < vec_size; j++) {
113
+ tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
114
+ }
115
+
116
+ // Zero out unneeded values
117
+ STEEL_PRAGMA_UNROLL
118
+ for (short j = 0; j < vec_size; j++) {
119
+ tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
120
+ }
121
+
122
+ // Copy values to threadgroup memory
123
+ STEEL_PRAGMA_UNROLL
124
+ for (short j = 0; j < vec_size; j++) {
125
+ dst[i * dst_ld + j] = tmp_val[j];
126
+ }
127
+ }
128
+ }
129
+
130
+ /* Iteration helper */
131
+ METAL_FUNC void next() {
132
+ src += tile_stride;
133
+ }
134
+ };
135
+
136
+ } // namespace steel
137
+ } // namespace mlx
bitsandbytes_mps/gemm/mma.h ADDED
@@ -0,0 +1,735 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_simdgroup>
6
+ #include <metal_simdgroup_matrix>
7
+ #include <metal_stdlib>
8
+
9
+ #include "gemm/defines.h"
10
+ #include "gemm/transforms.h"
11
+ #include "gemm/utils/integral_constant.h"
12
+
13
+ using namespace metal;
14
+
15
+ ///////////////////////////////////////////////////////////////////////////////
16
+ // MMA helper
17
+ ///////////////////////////////////////////////////////////////////////////////
18
+
19
+ namespace mlx {
20
+ namespace steel {
21
+
22
+ template <typename T, int kFragRows_, int kFragCols_>
23
+ struct BaseMMAFrag {
24
+ static_assert(
25
+ kFragRows_ == 8,
26
+ "Only 8 x 8 fragment matrices are currently supported");
27
+ static_assert(
28
+ kFragCols_ == 8,
29
+ "Only 8 x 8 fragment matrices are currently supported");
30
+ };
31
+
32
+ template <typename T>
33
+ struct BaseMMAFrag<T, 8, 8> {
34
+ STEEL_CONST int kFragRows = 8;
35
+ STEEL_CONST int kFragCols = 8;
36
+
37
+ STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
38
+
39
+ STEEL_CONST int kElemRows = 1;
40
+ STEEL_CONST int kElemCols = 2;
41
+
42
+ static_assert(
43
+ kElemRows * kElemCols == kElemsPerFrag,
44
+ "MMAFrag shape is not consistent with MMAFrag size");
45
+
46
+ typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
47
+ typedef metal::vec<T, kElemsPerFrag> frag_type;
48
+
49
+ METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
50
+ [[thread_index_in_simdgroup]]) {
51
+ const short qid = simd_lane_id / 4;
52
+ const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
53
+ const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
54
+ return short2{fn, fm};
55
+ }
56
+
57
+ template <typename SrcPtrType, typename StrX, typename StrY>
58
+ METAL_FUNC static constexpr void
59
+ load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
60
+ STEEL_PRAGMA_UNROLL
61
+ for (short i = 0; i < kElemRows; i++) {
62
+ STEEL_PRAGMA_UNROLL
63
+ for (short j = 0; j < kElemCols; j++) {
64
+ dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
65
+ }
66
+ }
67
+ }
68
+
69
+ template <
70
+ typename SrcPtrType,
71
+ typename StrX,
72
+ typename StrY,
73
+ typename LimX,
74
+ typename LimY,
75
+ typename OffX,
76
+ typename OffY>
77
+ METAL_FUNC static constexpr void load_safe(
78
+ thread frag_type& dst,
79
+ SrcPtrType src,
80
+ StrX str_x,
81
+ StrY str_y,
82
+ LimX lim_x,
83
+ LimY lim_y,
84
+ OffX off_x = Int<0>{},
85
+ OffY off_y = Int<0>{}) {
86
+ STEEL_PRAGMA_UNROLL
87
+ for (short i = 0; i < kElemRows; i++) {
88
+ STEEL_PRAGMA_UNROLL
89
+ for (short j = 0; j < kElemCols; j++) {
90
+ if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
91
+ dst[i * kElemCols + j] =
92
+ static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
93
+ } else {
94
+ dst[i * kElemCols + j] = T(0);
95
+ }
96
+ }
97
+ }
98
+ }
99
+
100
+ template <typename DstPtrType, typename StrX, typename StrY>
101
+ METAL_FUNC static constexpr void
102
+ store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
103
+ using U = pointer_element_t<DstPtrType>;
104
+
105
+ STEEL_PRAGMA_UNROLL
106
+ for (short i = 0; i < kElemRows; i++) {
107
+ STEEL_PRAGMA_UNROLL
108
+ for (short j = 0; j < kElemCols; j++) {
109
+ dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
110
+ }
111
+ }
112
+ }
113
+
114
+ template <
115
+ typename DstPtrType,
116
+ typename StrX,
117
+ typename StrY,
118
+ typename LimX,
119
+ typename LimY,
120
+ typename OffX,
121
+ typename OffY>
122
+ METAL_FUNC static constexpr void store_safe(
123
+ const thread frag_type& src,
124
+ DstPtrType dst,
125
+ StrX str_x,
126
+ StrY str_y,
127
+ LimX lim_x,
128
+ LimY lim_y,
129
+ OffX off_x = Int<0>{},
130
+ OffY off_y = Int<0>{}) {
131
+ using U = pointer_element_t<DstPtrType>;
132
+
133
+ STEEL_PRAGMA_UNROLL
134
+ for (short i = 0; i < kElemRows; i++) {
135
+ STEEL_PRAGMA_UNROLL
136
+ for (short j = 0; j < kElemCols; j++) {
137
+ if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
138
+ dst[(off_x + i) * str_x + (off_y + j) * str_y] =
139
+ static_cast<U>(src[i * kElemCols + j]);
140
+ }
141
+ }
142
+ }
143
+ }
144
+
145
+ template <
146
+ typename DstPtrType,
147
+ typename StrX,
148
+ typename StrY,
149
+ typename StartX,
150
+ typename StopX,
151
+ typename StartY,
152
+ typename StopY,
153
+ typename OffX,
154
+ typename OffY>
155
+ METAL_FUNC static constexpr void store_slice(
156
+ const thread frag_type& src,
157
+ DstPtrType dst,
158
+ StrX str_x,
159
+ StrY str_y,
160
+ StartX start_x,
161
+ StopX stop_x,
162
+ StartY start_y,
163
+ StopY stop_y,
164
+ OffX off_x = Int<0>{},
165
+ OffY off_y = Int<0>{}) {
166
+ using U = pointer_element_t<DstPtrType>;
167
+
168
+ STEEL_PRAGMA_UNROLL
169
+ for (short i = 0; i < kElemRows; i++) {
170
+ STEEL_PRAGMA_UNROLL
171
+ for (short j = 0; j < kElemCols; j++) {
172
+ if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
173
+ (off_y + j) < stop_y && (off_y + j) >= start_y) {
174
+ dst[(off_x + i) * str_x + (off_y + j) * str_y] =
175
+ static_cast<U>(src[i * kElemCols + j]);
176
+ }
177
+ }
178
+ }
179
+ }
180
+
181
+ METAL_FUNC static constexpr void mma(
182
+ thread frag_type& D,
183
+ thread frag_type& A,
184
+ thread frag_type& B,
185
+ thread frag_type& C) {
186
+ mat_type D_mat;
187
+ mat_type A_mat;
188
+ mat_type B_mat;
189
+ mat_type C_mat;
190
+
191
+ reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
192
+ reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
193
+ reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
194
+
195
+ mma(D_mat, A_mat, B_mat, C_mat);
196
+
197
+ D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
198
+ }
199
+
200
+ METAL_FUNC static constexpr void mma(
201
+ thread mat_type& D,
202
+ thread mat_type& A,
203
+ thread mat_type& B,
204
+ thread mat_type& C) {
205
+ simdgroup_multiply_accumulate(D, A, B, C);
206
+ }
207
+ };
208
+
209
+ template <
210
+ typename T,
211
+ int kTileRows_,
212
+ int kTileCols_,
213
+ class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
214
+ struct MMATile {
215
+ using MMAFrag_t = MMAFrag_;
216
+ using elem_type = T;
217
+ STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
218
+ STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
219
+ STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;
220
+
221
+ STEEL_CONST int kTileRows = kTileRows_;
222
+ STEEL_CONST int kTileCols = kTileCols_;
223
+
224
+ STEEL_CONST int kRows = kTileRows * kFragRows;
225
+ STEEL_CONST int kCols = kTileCols * kFragCols;
226
+
227
+ STEEL_CONST int kNumFrags = kTileRows * kTileCols;
228
+ STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;
229
+
230
+ typedef typename MMAFrag_t::mat_type mat_type;
231
+ typedef typename MMAFrag_t::frag_type frag_type;
232
+
233
+ frag_type val_frags[kNumFrags] = {frag_type(0)};
234
+
235
+ METAL_FUNC MMATile() thread {}
236
+
237
+ METAL_FUNC constexpr void clear() {
238
+ STEEL_PRAGMA_UNROLL
239
+ for (short i = 0; i < kNumFrags; ++i) {
240
+ val_frags[i] = frag_type(0);
241
+ }
242
+ }
243
+
244
+ METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
245
+ return val_frags[i * kTileCols + j];
246
+ }
247
+
248
+ METAL_FUNC constexpr const thread frag_type& frag_at(
249
+ const short i,
250
+ const short j) const {
251
+ return val_frags[i * kTileCols + j];
252
+ }
253
+
254
+ METAL_FUNC mat_type mat_at(const short i, const short j) {
255
+ mat_type val_mat;
256
+ STEEL_PRAGMA_UNROLL
257
+ for (short ii = 0; ii < kElemsPerFrag; ++ii) {
258
+ val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
259
+ }
260
+ return val_mat;
261
+ }
262
+
263
+ METAL_FUNC thread elem_type* elems() {
264
+ return reinterpret_cast<thread elem_type*>(val_frags);
265
+ }
266
+
267
+ METAL_FUNC const thread elem_type* elems() const {
268
+ return reinterpret_cast<const thread elem_type*>(val_frags);
269
+ }
270
+
271
+ template <typename U, int w_x, int w_y, int str_x, int str_y>
272
+ METAL_FUNC void load(const threadgroup U* src) {
273
+ STEEL_PRAGMA_UNROLL
274
+ for (short i = 0; i < kTileRows; ++i) {
275
+ STEEL_PRAGMA_UNROLL
276
+ for (short j = 0; j < kTileCols; ++j) {
277
+ MMAFrag_t::load(
278
+ frag_at(i, j),
279
+ &(
280
+ src[(i * kFragRows) * w_x * str_x +
281
+ (j * kFragCols) * w_y * str_y]),
282
+ Int<str_x>{},
283
+ Int<str_y>{});
284
+ }
285
+ }
286
+ }
287
+
288
+ template <typename U, int w_x, int w_y, int str_x, int str_y>
289
+ METAL_FUNC void store(threadgroup U* dst) const {
290
+ STEEL_PRAGMA_UNROLL
291
+ for (short i = 0; i < kTileRows; ++i) {
292
+ STEEL_PRAGMA_UNROLL
293
+ for (short j = 0; j < kTileCols; ++j) {
294
+ MMAFrag_t::store(
295
+ frag_at(i, j),
296
+ &(
297
+ dst[(i * kFragRows) * w_x * str_x +
298
+ (j * kFragCols) * w_y * str_y]),
299
+ Int<str_x>{},
300
+ Int<str_y>{});
301
+ }
302
+ }
303
+ }
304
+
305
+ template <typename U, int w_x, int w_y>
306
+ METAL_FUNC void load(const device U* src, const int ld) {
307
+ STEEL_PRAGMA_UNROLL
308
+ for (short i = 0; i < kTileRows; ++i) {
309
+ STEEL_PRAGMA_UNROLL
310
+ for (short j = 0; j < kTileCols; ++j) {
311
+ MMAFrag_t::load(
312
+ frag_at(i, j),
313
+ &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
314
+ ld,
315
+ Int<1>{});
316
+ }
317
+ }
318
+ }
319
+
320
+ template <typename U, int w_x, int w_y>
321
+ METAL_FUNC void store(device U* dst, const int ld) const {
322
+ STEEL_PRAGMA_UNROLL
323
+ for (short i = 0; i < kTileRows; ++i) {
324
+ STEEL_PRAGMA_UNROLL
325
+ for (short j = 0; j < kTileCols; ++j) {
326
+ MMAFrag_t::store(
327
+ frag_at(i, j),
328
+ &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
329
+ ld,
330
+ Int<1>{});
331
+ }
332
+ }
333
+ }
334
+
335
// Bounds-checked load from device memory.
// `src_tile_dims` carries the valid extent as (x = columns, y = rows) based
// on how .y/.x are forwarded here; per-fragment offsets are handed to
// MMAFrag_t::load_safe, which masks out-of-range elements.
template <typename U, int w_x, int w_y>
METAL_FUNC void
load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
  STEEL_PRAGMA_UNROLL
  for (int i = 0; i < kTileRows; ++i) {
    STEEL_PRAGMA_UNROLL
    for (int j = 0; j < kTileCols; ++j) {
      MMAFrag_t::load_safe(
          frag_at(i, j),
          src,
          ld,
          Int<1>{},
          src_tile_dims.y,
          src_tile_dims.x,
          (i * kFragRows) * w_x,
          (j * kFragCols) * w_y);
    }
  }
}
354
+
355
// Bounds-checked store to device memory — mirror of load_safe.
// `dst_tile_dims` is (x = columns, y = rows); elements past the valid
// extent are skipped by MMAFrag_t::store_safe.
template <typename U, int w_x, int w_y>
METAL_FUNC void
store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
  STEEL_PRAGMA_UNROLL
  for (int i = 0; i < kTileRows; ++i) {
    STEEL_PRAGMA_UNROLL
    for (int j = 0; j < kTileCols; ++j) {
      MMAFrag_t::store_safe(
          frag_at(i, j),
          dst,
          ld,
          Int<1>{},
          dst_tile_dims.y,
          dst_tile_dims.x,
          (i * kFragRows) * w_x,
          (j * kFragCols) * w_y);
    }
  }
}
374
+
375
// Store only the elements falling inside the half-open window
// [start, stop) — start/stop are (x = column, y = row) bounds forwarded to
// MMAFrag_t::store_slice, which performs the per-element clipping.
template <typename U, int w_x, int w_y>
METAL_FUNC void store_slice(
    device U* dst,
    const int ld,
    const short2 start,
    const short2 stop) const {
  STEEL_PRAGMA_UNROLL
  for (int i = 0; i < kTileRows; ++i) {
    STEEL_PRAGMA_UNROLL
    for (int j = 0; j < kTileCols; ++j) {
      MMAFrag_t::store_slice(
          frag_at(i, j),
          dst,
          ld,
          Int<1>{},
          start.y,
          stop.y,
          start.x,
          stop.x,
          (i * kFragRows) * w_x,
          (j * kFragCols) * w_y);
    }
  }
}
399
+ };
400
+
401
// Tile-level matrix multiply-accumulate: D = A * B + C over (M x K) and
// (K x N) fragment tiles, dispatching each 8x8 fragment product to
// MMAFrag_t::mma. D and C may alias (callers pass the same tile for both).
template <typename T, typename U, int M, int N, int K>
METAL_FUNC void tile_matmad(
    thread MMATile<T, M, N>& D,
    thread MMATile<U, M, K>& A,
    thread MMATile<U, K, N>& B,
    thread MMATile<T, M, N>& C) {
  STEEL_PRAGMA_UNROLL
  for (short m = 0; m < M; ++m) {
    STEEL_PRAGMA_UNROLL
    for (short n = 0; n < N; ++n) {
      // Serpentine (boustrophedon) ordering over n: odd rows traverse the
      // columns in reverse — presumably to improve B-fragment register
      // reuse between consecutive iterations (NOTE: intent inferred from
      // the pattern; confirm against upstream kernels).
      short n_serp = (m % 2) ? (N - 1 - n) : n;
      STEEL_PRAGMA_UNROLL
      for (short k = 0; k < K; ++k) {
        MMATile<T, M, N>::MMAFrag_t::mma(
            D.frag_at(m, n_serp),
            A.frag_at(m, k),
            B.frag_at(k, n_serp),
            C.frag_at(m, n_serp));
      }
    }
  }
}
423
+
424
// BlockMMA: cooperative (BM x BK) @ (BK x BN) multiply-accumulate for one
// threadgroup, built from 8x8 simdgroup MMA fragments. WM x WN simdgroups
// partition the output tile; transpose_a / transpose_b select the
// threadgroup-memory layouts and lda_tgp / ldb_tgp are the threadgroup
// leading dimensions. Results accumulate in Ctile (AccumType) and are
// converted to U by `Epilogue` on store.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    short lda_tgp,
    short ldb_tgp,
    typename AccumType = float,
    typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
  // MMAFrag size
  STEEL_CONST short kFragSize = 8;
  using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;

  // Warp tile simdgroup matrix strides along M
  STEEL_CONST short TM_stride = kFragSize * WM;
  // Warp tile simdgroup matrix strides along N
  STEEL_CONST short TN_stride = kFragSize * WN;

  // Warp tile size along M
  STEEL_CONST short TM = BM / (kFragSize * WM);
  // Warp tile size along N
  STEEL_CONST short TN = BN / (kFragSize * WN);

  // Threadgroup A strides (element strides along M and K, layout-dependent)
  STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
  STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K

  // Threadgroup B strides (element strides along K and N, layout-dependent)
  STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
  STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N

  // Threadgroup strides along K: advance per kFragSize-wide K step
  STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
  STEEL_CONST short tile_stride_b = kFragSize * B_str_k;

  // Simdgroup matrices (register-resident fragment tiles)
  MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
  MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
  MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;

  // Offsets within threadgroup (this thread's row/col in the block tile)
  short sm;
  short sn;

  // Precomputed element offsets of this thread into As / Bs
  short As_offset;
  short Bs_offset;

  /* Constructor: derive this thread's position from its simdgroup id and
     lane id, and precompute its offsets into the threadgroup tiles. */
  METAL_FUNC BlockMMA(
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]]) {
    // Determine thread position in simdgroup matrix
    short tm = kFragSize * (simd_group_id / WN);
    short tn = kFragSize * (simd_group_id % WN);

    short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
    sm = simd_coord.y;
    sn = simd_coord.x;

    // Determine thread and simdgroup offset
    As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
    Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N

    // From here on sm/sn are absolute within the block tile.
    sm += tm;
    sn += tn;
  }

  /* (BM, BK) X (BK, BN) multiply accumulate function */
  METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
    // Adjust for simdgroup and thread location
    As += As_offset;
    Bs += Bs_offset;

    // Iterate over BK in blocks of kFragSize
    STEEL_PRAGMA_UNROLL
    for (short kk = 0; kk < BK; kk += kFragSize) {
      simdgroup_barrier(mem_flags::mem_none);

      Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);

      simdgroup_barrier(mem_flags::mem_none);

      Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);

      simdgroup_barrier(mem_flags::mem_none);

      tile_matmad(Ctile, Atile, Btile, Ctile);

      // Progress to next simdgroup tile
      As += tile_stride_a;
      Bs += tile_stride_b;
    }
  }

  /* Store results from simdgroup_matrix results into device memory,
     applying the type-conversion Epilogue first. No bounds checks. */
  METAL_FUNC void store_result(device U* D, const int ldd) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;

    Ctile.template store<U, WM, WN>(D, ldd);
  }

  /* Store only the [start, stop) window of the result tile. */
  METAL_FUNC void
  store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Rebase window to this thread's coordinates.
    D += sm * ldd + sn;
    start -= short2(sn, sm);
    stop -= short2(sn, sm);

    // TODO: Check the start as well
    if (stop.y <= 0 || stop.x <= 0) {
      return;
    }

    Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
  }

  /* Bounds-checked store for partially-covered output tiles. */
  METAL_FUNC void
  store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
    // Apply epilogue
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
    }

    // Adjust for simdgroup and thread location
    D += sm * ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    // This thread has nothing inside the valid region.
    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
  }

  /* Apply epilogue (element-wise, in registers) */
  template <typename UnaryEpilogue>
  METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
      Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
    }
  }

  /* Apply epilogue combining the accumulator with a device-memory C
     operand; fdc is C's fastest-moving (column) stride. No bounds checks. */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue(
      const device U* C,
      const int ldc,
      const int fdc,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
          accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Apply epilogue with bounds checks on reads from C; out-of-range C
     elements are substituted with zero. */
  template <typename BinaryEpilogue>
  METAL_FUNC void apply_epilogue_safe(
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const BinaryEpilogue& epilogue_op) {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;

        constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

        // Read C (zero-fill the out-of-range lanes)
        U c_elems[kelems] = {0};

        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          if ((j * TN_stride + k) < dst_tile_dims.x) {
            c_elems[k] = C[offset_c + k * fdc];
          }
        }

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
        }
      }
    }
  }

  /* Store results from simdgroup_matrix results into device memory,
     fusing the C read and epilogue into the store. No bounds checks. */
  METAL_FUNC void store_result(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    // Loop over all simdgroup tiles
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < TM; i++) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < TN; j++) {
        // Get accumulated result and associated offset in C
        thread const auto& accum = Ctile.frag_at(i, j);
        int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
        int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

        // Apply epilogue
        STEEL_PRAGMA_UNROLL
        for (short k = 0; k < kelems; k++) {
          D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
        }
      }
    }
  }

  /* Bounds-checked variant of the fused C + epilogue store. */
  METAL_FUNC void store_result_safe(
      device U* D,
      const int ldd,
      const device U* C,
      const int ldc,
      const int fdc,
      short2 dst_tile_dims,
      thread const Epilogue& epilogue_op) const {
    // Adjust for simdgroup and thread location
    C += (sm)*ldc + (sn)*fdc;
    D += (sm)*ldd + sn;
    dst_tile_dims -= short2(sn, sm);

    if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
      return;

    constexpr short kelems = decltype(Ctile)::kElemsPerFrag;

    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < TM; i++) {
      // Skip fragment rows entirely below the valid region.
      if (i * TM_stride < dst_tile_dims.y) {
        STEEL_PRAGMA_UNROLL
        for (int j = 0; j < TN; j++) {
          // Get accumulated result and associated offset in C
          thread const auto& accum = Ctile.frag_at(i, j);
          int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
          int offset_d = (i * TM_stride) * ldd + (j * TN_stride);

          // Apply epilogue
          STEEL_PRAGMA_UNROLL
          for (short k = 0; k < kelems; k++) {
            if ((j * TN_stride + k) < dst_tile_dims.x) {
              D[offset_d + k] =
                  epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
            }
          }
        }
      }
    }
  }
};
733
+
734
+ } // namespace steel
735
+ } // namespace mlx
bitsandbytes_mps/gemm/params.h ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ ///////////////////////////////////////////////////////////////////////////////
6
+ // GEMM param classes
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+
9
namespace mlx {
namespace steel {

// Host-filled argument block describing one batched GEMM launch:
// D (M x N) = A (M x K) * B (K x N). Leading dimensions follow BLAS naming.
struct GEMMParams {
  // Problem dimensions.
  const int M;
  const int N;
  const int K;

  // Leading dimensions (row strides, in elements) of A, B and output D.
  const int lda;
  const int ldb;
  const int ldd;

  // Threadgroup tile counts along N and M.
  const int tiles_n;
  const int tiles_m;

  // Per-batch element strides for A, B and D.
  const int64_t batch_stride_a;
  const int64_t batch_stride_b;
  const int64_t batch_stride_d;

  // log2 of the threadgroup swizzle width (see BlockSwizzle::swizzle).
  const int swizzle_log;
  // Number of K iterations fully covered by aligned block loads.
  const int gemm_k_iterations_aligned;

  // Number of leading batch dimensions.
  const int batch_ndim;
};

// Parameters for the split-K GEMM variant, where K is partitioned across
// threadgroups and partial sums are written to C for later accumulation.
// NOTE(review): "Spilt" spelling is kept as-is — it is the name callers use.
struct GEMMSpiltKParams {
  // Problem dimensions.
  const int M;
  const int N;
  const int K;

  // Leading dimensions of A, B and the partial-sum output C.
  const int lda;
  const int ldb;
  const int ldc;

  // Threadgroup tile counts along N and M.
  const int tiles_n;
  const int tiles_m;

  // Number of K partitions, the element stride between partition outputs,
  // and the K extent handled by each partition.
  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  // Number of K iterations fully covered by aligned block loads.
  const int gemm_k_iterations_aligned;
};

// Extra parameters for addmm-style GEMM (D = alpha * A@B + beta * C).
struct GEMMAddMMParams {
  // Leading dimension and fastest-moving (column) stride of C.
  const int ldc;
  const int fdc;

  // Per-batch element stride for C.
  const int64_t batch_stride_c;

  // Epilogue scaling factors (see TransformAxpby).
  const float alpha;
  const float beta;
};

} // namespace steel
} // namespace mlx
+ } // namespace mlx
bitsandbytes_mps/gemm/transforms.h ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "gemm/utils.h"
6
+
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+ // Transforms and Epilogues
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+
11
namespace mlx {
namespace steel {

// Epilogue that only converts the accumulator to the output type;
// the binary overload ignores the C operand.
template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

// Epilogue computing D = accum + C. The (float, float) constructor exists
// only so it is signature-compatible with TransformAxpby; the arguments
// are unused.
template <typename OutT, typename InT>
struct TransformAdd {
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

// Epilogue computing D = alpha * accum + beta * C. The unary overload
// deliberately applies no scaling (it is used when there is no C operand).
template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(
        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
  }
};

// Accumulation type selector — every input type accumulates in float.
template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

// Remaps threadgroup ids so that 2^swizzle_log consecutive x-tiles are
// grouped per y-row, improving L2/cache locality of tile scheduling.
struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};

} // namespace steel
} // namespace mlx
bitsandbytes_mps/gemm/utils.h ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_stdlib>
6
+
7
+ METAL_FUNC ulong2 elem_to_loc_broadcast(
8
+ uint elem,
9
+ constant const int* shape,
10
+ constant const int64_t* a_strides,
11
+ constant const int64_t* b_strides,
12
+ int ndim) {
13
+ ulong loc_a{0};
14
+ ulong loc_b{0};
15
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
16
+ int pos_in_dim = (elem % shape[i]);
17
+ elem /= shape[i];
18
+ loc_a += pos_in_dim * a_strides[i];
19
+ loc_b += pos_in_dim * b_strides[i];
20
+ }
21
+ return ulong2(loc_a, loc_b);
22
+ }
23
+
24
+ METAL_FUNC ulong3 elem_to_loc_broadcast(
25
+ uint elem,
26
+ constant const int* shape,
27
+ constant const int64_t* a_strides,
28
+ constant const int64_t* b_strides,
29
+ constant const int64_t* c_strides,
30
+ int ndim) {
31
+ ulong loc_a{0};
32
+ ulong loc_b{0};
33
+ ulong loc_c{0};
34
+ for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
35
+ int pos_in_dim = (elem % shape[i]);
36
+ elem /= shape[i];
37
+ loc_a += pos_in_dim * a_strides[i];
38
+ loc_b += pos_in_dim * b_strides[i];
39
+ loc_c += pos_in_dim * c_strides[i];
40
+ }
41
+ return ulong3(loc_a, loc_b, loc_c);
42
+ }
bitsandbytes_mps/gemm/utils/integral_constant.h ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_stdlib>
6
+ #include "gemm/utils/type_traits.h"
7
+
8
+ #pragma METAL internals : enable
9
+
10
namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

// Compile-time constant wrapper (mirrors std::integral_constant); used to
// pass strides/sizes as types so the compiler can fold address arithmetic.
template <typename T, T v>
struct integral_constant {
  static constexpr constant T value = v;
  using value_type = T;
  using type = integral_constant;

  // Implicit conversion lets an integral_constant be used wherever a plain
  // value of T is expected.
  METAL_FUNC constexpr operator value_type() const noexcept {
    return value;
  }

  // METAL_FUNC constexpr value_type operator()() const noexcept {
  //   return value;
  // }
};

template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;

// is_integral: true for builtin integral types and for any
// integral_constant wrapping one.
template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};

template <class T, T v>
struct is_integral<integral_constant<T, v>>
    : bool_constant<metal::is_integral<T>::value> {};

template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;

// Shorthand for compile-time ints, e.g. Int<1>{} as a unit stride.
template <int val>
using Int = integral_constant<int, val>;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

// Each operator combines two integral_constants into a new one whose value
// (and type) is computed at compile time.
#define integral_const_binop(__op__, __operator__)    \
  template <typename T, T tv, typename U, U uv>       \
  METAL_FUNC constexpr auto __operator__(             \
      integral_constant<T, tv>, integral_constant<U, uv>) { \
    constexpr auto res = tv __op__ uv;                \
    return integral_constant<decltype(res), res>{};   \
  }

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

// Short-circuit overloads: true || x and x || true are true regardless of x
// (x need not be an integral constant).
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(true_type, T) {
  return true_type{};
}
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(T, true_type) {
  return true_type{};
}

// Likewise, false && x and x && false are false.
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(false_type, T) {
  return false_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(T, false_type) {
  return false_type{};
}

// Dispatch utilities
// Convert a runtime bool into a compile-time bool_constant and invoke f
// with it (f is typically a generic lambda instantiated for both cases).
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(true_type{});
  } else {
    f(false_type{});
  }
}

// Compile-time for loop: invokes f(Int<start>{}), f(Int<start+step>{}), ...
// while the index is below `stop`; fully unrolled via recursion.
template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    constexpr auto idx = Int<start>{};
    f(idx);
    const_for_loop<start + step, stop, step, F>(f);
  }
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

// Variadic compile-time-friendly sum: sum(a) = a; sum(a, b, ...) folds left.
template <typename T>
METAL_FUNC constexpr T sum(T x) {
  return x;
}

template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

} // namespace steel
} // namespace mlx
132
+ } // namespace mlx
133
+
134
+ #pragma METAL internals : disable
bitsandbytes_mps/gemm/utils/type_traits.h ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_stdlib>
6
+
7
+ #pragma METAL internals : enable
8
+
9
namespace metal {

// True when T has no non-static data members (compiler intrinsic).
template <typename T>
struct is_empty : metal::bool_constant<__is_empty(T)> {};

#ifdef __cpp_variable_templates
template <typename T>
constexpr constant bool is_empty_v = is_empty<T>::value;
#endif

// void_t machinery (as in C++17 std::void_t) for SFINAE detection idioms.
template <typename... Ts>
struct make_void {
  typedef void type;
};

template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

// A "static" type here is one carrying no runtime state (empty after
// stripping cv-qualifiers).
template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};

// pointer_element: strip the pointer and its Metal address-space qualifier,
// yielding the pointee type. One specialization per address space.
template <typename T>
struct pointer_element {};

template <typename T>
struct pointer_element<thread T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
  using type = remove_cv_t<T>;
};

template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;

} // namespace metal
54
+
55
+ #pragma METAL internals : disable
bitsandbytes_mps/quantized_utils.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <metal_simdgroup>
4
+ #include <metal_stdlib>
5
+
6
// Main K-loop for tiles fully covered by the block size: unguarded loads
// into threadgroup memory, barrier, MMA, advance. Barriers separate the
// previous iteration's compute from this iteration's loads.
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_aligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    loader_a.load_unsafe();
    loader_b.load_unsafe();

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}
31
+
32
// K-loop for edge tiles: each operand independently uses guarded
// (load_safe) or unguarded (load_unsafe) loads depending on whether its
// rows/columns are block-aligned; the `transpose` flag picks the safe-load
// extents for B's layout. Alignment flags are compile-time, so the
// branches fold away per instantiation.
template <
    bool rows_aligned,
    bool cols_aligned,
    bool transpose,
    typename T,
    typename mma_t,
    typename loader_a_t,
    typename loader_b_t>
METAL_FUNC void gemm_loop_unaligned(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const int k_iterations,
    const short tgp_bm,
    const short tgp_bn,
    const short tgp_bk) {
  for (int k = 0; k < k_iterations; k++) {
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Load elements into threadgroup memory
    if (rows_aligned) {
      loader_a.load_unsafe();
    } else {
      loader_a.load_safe(short2(tgp_bk, tgp_bm));
    }
    if (cols_aligned) {
      loader_b.load_unsafe();
    } else {
      loader_b.load_safe(
          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Multiply and accumulate threadgroup elements
    mma_op.mma(As, Bs);

    // Prepare for next iteration
    loader_a.next();
    loader_b.next();
  }
}
76
+
77
// Final (remainder) K step: one guarded load of the partial tiles given by
// `tile_a` / `tile_b`, then a last MMA. No trailing barrier/next — this is
// the last use of the threadgroup buffers.
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);
  threadgroup_barrier(mem_flags::mem_threadgroup);
  mma_op.mma(As, Bs);
}
bitsandbytes_mps/utils.h ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_math>
6
+
7
+ #include "bf16.h"
8
+ #include "defines.h"
9
+
10
+ typedef half float16_t;
11
+
12
+ // Work per thread values for different types. The values here are expected to
13
+ // match get_work_per_thread in mlx/backend/metal/utils.h
14
// Elements processed per thread for type U: sized so each thread handles
// 8 bytes (e.g. 2 floats, 4 halfs). Must stay in sync with
// get_work_per_thread in mlx/backend/metal/utils.h (per file comment).
template <typename U>
struct WorkPerThread {
  static_assert(sizeof(U) <= 8, "Type too large");
  static constexpr int constant n = 8 / sizeof(U);
};
19
+
20
+ ///////////////////////////////////////////////////////////////////////////////
21
+ // Type limits utils
22
+ ///////////////////////////////////////////////////////////////////////////////
23
+
24
// Numeric limits with both "representable" (max/min) and "finite" bounds.
// For integer types the two coincide (generic template and default macro);
// for floating types max/min are +/-infinity while finite_max/finite_min
// are the largest finite magnitudes.
template <typename U>
struct Limits {
  static const constant U max = metal::numeric_limits<U>::max();
  static const constant U min = metal::numeric_limits<U>::min();
  static const constant U finite_max = metal::numeric_limits<U>::max();
  static const constant U finite_min = metal::numeric_limits<U>::min();
};

// Integer specialization: all four bounds come straight from numeric_limits.
#define instantiate_default_limit(type)                                    \
  template <>                                                              \
  struct Limits<type> {                                                    \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_max =                            \
        metal::numeric_limits<type>::max();                                \
    static constexpr constant type finite_min =                            \
        metal::numeric_limits<type>::min();                                \
  };

instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);

// Floating-point specialization: max/min are infinities, finite_* are the
// extreme finite values (note finite_min is -max, not the denormal min).
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };

instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);

// bool: only max/min make sense.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
};
74
+
75
+ // complex64_t specialization removed - not needed for BnB kernels
76
+
77
+ ///////////////////////////////////////////////////////////////////////////////
78
+ // Indexing utils
79
+ ///////////////////////////////////////////////////////////////////////////////
80
+
81
+ #define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
82
+
83
+ ///////////////////////////////////////////////////////////////////////////////
84
+ // Single Array with generic dims
85
+
86
// Convert a flat (row-major over `shape`) element index into a memory
// offset for a strided array. Iterates innermost-to-outermost and exits
// early once the remaining index is zero.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc = 0;
  for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
    loc += (elem % shape[i]) * IdxT(strides[i]);
    elem /= shape[i];
  }
  return loc;
}
99
+
100
// Variant taking a 3D launch index: x and y map directly to the two
// innermost dims, and z is decomposed across all remaining (outer) dims —
// handles arbitrary ndim from a 3D grid.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc =
      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
  for (int d = ndim - 3; d >= 0; --d) {
    loc += (elem.z % shape[d]) * IdxT(strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}
115
+
116
+ ///////////////////////////////////////////////////////////////////////////////
117
+ // Single Array with fixed N dims
118
+
119
// Fixed-rank index-to-offset helpers for 1/2/3 dimensions. Component x is
// always the innermost (fastest-moving) axis.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  return elem * IdxT(stride);
}

template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
}

template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
      elem.z * IdxT(strides[0]);
}
134
+
135
+ ///////////////////////////////////////////////////////////////////////////////
136
+ // Multiple Arrays with generic dims
137
+
138
// Compute offsets into TWO arrays sharing one shape but with independent
// strides (broadcast pairs). x/y index the two innermost dims; z is
// decomposed across the remaining outer dims.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  vec<IdxT, 2> loc = {
      IdxT(
          elem.x * IdxT(a_strides[ndim - 1]) +
          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
      IdxT(
          elem.x * IdxT(b_strides[ndim - 1]) +
          elem.y * IdxT(b_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}
160
+
161
// Compute offsets into THREE arrays sharing one shape but with independent
// strides. Same decomposition as elem_to_loc_2_nd, extended to a third
// stride set.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  vec<IdxT, 3> loc = {
      IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(a_strides[ndim - 2])),
      IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(b_strides[ndim - 2])),
      IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
          IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    loc.z += l * IdxT(c_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}
185
+
186
+ ///////////////////////////////////////////////////////////////////////////////
187
+ // Elem to loc in a loop utils
188
+ ///////////////////////////////////////////////////////////////////////////////
189
+
190
// Incrementally tracks the strided memory offset of a logical element index
// inside a loop, avoiding a full div/mod chain on every step. This primary
// template owns the fastest-moving tracked axis (shape[dim - 1]) and
// delegates carry-over to an inner looper that owns the next-slower axis.
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
  int dim;  // axis count handled at this level and below
  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;  // slower axes
  OffsetT offset{0};  // current memory offset
  int index{0};       // position along axis (dim - 1)

  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  // Advance by one element along the innermost tracked axis. On wrap-around
  // carry into the inner looper and re-derive the offset from it.
  void next(const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  // Advance by n elements at once. Whole wraps of the innermost axis are
  // folded into a single multi-step carry on the inner looper; the leftover
  // partial step is re-applied via a bounded recursive call.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);

    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        // More than one full wrap: push all of them into the inner looper.
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        // Re-apply the remainder now that the carry has been absorbed.
        next(extra, shape, strides);
      }
    }
  }

  OffsetT location() {
    return offset;
  }
};
239
+
240
// DIM == 1 specialization for the general (strided) case: keeps a flat
// element index and, when more than one axis is collapsed into this level,
// recomputes the offset with elem_to_loc instead of chaining loopers.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
  int dim;            // number of axes collapsed into this looper
  OffsetT offset{0};  // current memory offset
  uint index{0};      // flat element index

  LoopedElemToLoc(int dim) : dim(dim) {}

  void next(const constant int* shape, const constant int64_t* strides) {
    index++;
    if (dim > 1) {
      // Multi-axis: re-derive the strided offset from the flat index.
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      // Single axis: just step by its stride.
      offset += OffsetT(strides[0]);
    }
  }

  void next(int n, const constant int* shape, const constant int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  OffsetT location() {
    return offset;
  }
};
270
+
271
// DIM == 1, non-general specialization: a single stride fully describes the
// layout, so no shape bookkeeping is needed at all.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};  // current memory offset

  LoopedElemToLoc(int) {}

  void next(const constant int*, const constant int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  void next(int n, const constant int*, const constant int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  OffsetT location() {
    return offset;
  }
};
289
+
290
+ ///////////////////////////////////////////////////////////////////////////////
291
+ // Calculation utils
292
+ ///////////////////////////////////////////////////////////////////////////////
293
+
294
/** Integer ceiling division: smallest integer >= N / M. */
// NOTE(review): uses the (N + M - 1) / M identity — assumes M > 0,
// N >= 0, and that N + M - 1 does not overflow the promoted type.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  return (N + M - 1) / M;
}
299
+
300
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// Numerically stable log(1 + x). The naive form loses precision for tiny x
// because 1 + x rounds; Goldberg's identity
//   log1p(x) = x * log(xp1) / (xp1 - 1), with xp1 = 1 + x,
// compensates for that rounding error.
inline float log1p(float x) {
  float xp1 = 1.0f + x;
  if (xp1 == Limits<float>::max) {
    return Limits<float>::max;  // saturate rather than overflow
  }
  if (xp1 == 1.0f) {
    return x;  // x below float epsilon: log(1 + x) == x to precision
  }

  return x * (metal::log(xp1) / (xp1 - 1.0f));
}

// bfloat16 overload: the computation runs in float and is narrowed back
// to bfloat16 at the end.
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast<float>(x);
  if (xp1 == Limits<float>::max) {
    return Limits<bfloat16_t>::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }

  return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
324
+
325
+ ///////////////////////////////////////////////////////////////////////////////
326
+ // SIMD shuffle ops
327
+ ///////////////////////////////////////////////////////////////////////////////
328
+
329
// Metal's simd_shuffle* intrinsics do not provide 64-bit or bool overloads.
// The 64-bit variants reinterpret the value as a uint2 pair, shuffle both
// 32-bit halves together, and reinterpret back; bool is widened to uint32
// for the round trip.
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  return as_type<int64_t>(
      metal::simd_shuffle_down(as_type<uint2>(data), delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}

inline bool simd_shuffle_up(bool data, uint16_t delta) {
  return simd_shuffle_up(static_cast<uint32_t>(data), delta);
}

// Shuffle-up with an explicit fill value for lanes shifted in from below.
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
      as_type<uint2>(data), as_type<uint2>(filling), delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
      as_type<uint2>(data), as_type<uint2>(filling), delta));
}

inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  return simd_shuffle_and_fill_up(
      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}

// Broadcast-style shuffle: read `data` from an arbitrary lane.
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}

inline bool simd_shuffle(bool data, uint16_t lane) {
  return simd_shuffle(static_cast<uint32_t>(data), lane);
}
383
+
384
// Minimal stand-in for std::conditional, which Metal's stdlib does not ship.
// ConditionalType<C, T, U>::type is T when C is true, U otherwise.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = T;
};

template <typename T, typename U>
struct ConditionalType<false, T, U> {
  using type = U;
};
build.toml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "bitsandbytes_mps"
3
+ backends = ["metal"]
4
+
5
+ [torch]
6
+ minver = "2.9"
7
+ src = [
8
+ "torch-ext/torch_binding.cpp",
9
+ "torch-ext/torch_binding.h",
10
+ ]
11
+
12
+ [general.hub]
13
+ repo-id = "kernels-community/bitsandbytes-mps"
14
+
15
+ [kernel.bitsandbytes_mps]
16
+
17
+ depends = ["torch"]
18
+ backend = "metal"
19
+
20
+ src = [
21
+ # Utility headers (from MLX)
22
+ "bitsandbytes_mps/bf16.h",
23
+ "bitsandbytes_mps/bf16_math.h",
24
+ "bitsandbytes_mps/complex.h",
25
+ "bitsandbytes_mps/defines.h",
26
+ "bitsandbytes_mps/utils.h",
27
+
28
+ # GEMM infrastructure (from MLX steel)
29
+ "bitsandbytes_mps/gemm/defines.h",
30
+ "bitsandbytes_mps/gemm/gemm.h",
31
+ "bitsandbytes_mps/gemm/loader.h",
32
+ "bitsandbytes_mps/gemm/mma.h",
33
+ "bitsandbytes_mps/gemm/params.h",
34
+ "bitsandbytes_mps/gemm/transforms.h",
35
+ "bitsandbytes_mps/gemm/utils.h",
36
+ "bitsandbytes_mps/gemm/utils/integral_constant.h",
37
+ "bitsandbytes_mps/gemm/utils/type_traits.h",
38
+
39
+ # Quantized matmul utilities (from MLX)
40
+ "bitsandbytes_mps/quantized_utils.h",
41
+
42
+ # BnB-specific: codebook types, kernel logic, Metal shaders, dispatch
43
+ "bitsandbytes_mps/bnb_types.h",
44
+ "bitsandbytes_mps/bnb_quantized.h",
45
+ "bitsandbytes_mps/bnb_quantized.metal",
46
+ "bitsandbytes_mps/bnb_quantized.mm",
47
+ ]
48
+
49
+ include = ["bitsandbytes_mps"]
build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9be20185521046ee042d66544cf94fa448c0e1c0455217ec81cef718d264ed9
3
+ size 845176
build/torch210-metal-aarch64-darwin/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _bitsandbytes_mps_1c65113_dirty
3
- ops = torch.ops._bitsandbytes_mps_1c65113_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_bitsandbytes_mps_1c65113_dirty::{op_name}"
 
1
  import torch
2
+ from . import _bitsandbytes_mps_9811962_dirty
3
+ ops = torch.ops._bitsandbytes_mps_9811962_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_bitsandbytes_mps_9811962_dirty::{op_name}"
build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be7a2bbf3cae711200855b297de2f3ba3d47379bf2ce52c61dd6cc3053075047
3
+ size 844504
build/torch29-metal-aarch64-darwin/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _bitsandbytes_mps_1c65113_dirty
3
- ops = torch.ops._bitsandbytes_mps_1c65113_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_bitsandbytes_mps_1c65113_dirty::{op_name}"
 
1
  import torch
2
+ from . import _bitsandbytes_mps_9811962_dirty
3
+ ops = torch.ops._bitsandbytes_mps_9811962_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_bitsandbytes_mps_9811962_dirty::{op_name}"
flake.lock ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1765121682,
6
+ "narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs"
41
+ },
42
+ "locked": {
43
+ "lastModified": 1769448133,
44
+ "narHash": "sha256-XOp8+8u7fmXn1f63mJ40dPj/OHPMKtL9o4q7y0CUZFU=",
45
+ "owner": "huggingface",
46
+ "repo": "kernel-builder",
47
+ "rev": "078351df6e0fddb4a1a41ba3ffb8b804f58c4c6a",
48
+ "type": "github"
49
+ },
50
+ "original": {
51
+ "owner": "huggingface",
52
+ "repo": "kernel-builder",
53
+ "type": "github"
54
+ }
55
+ },
56
+ "nixpkgs": {
57
+ "locked": {
58
+ "lastModified": 1766341660,
59
+ "narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
60
+ "owner": "NixOS",
61
+ "repo": "nixpkgs",
62
+ "rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
63
+ "type": "github"
64
+ },
65
+ "original": {
66
+ "owner": "NixOS",
67
+ "ref": "nixos-unstable-small",
68
+ "repo": "nixpkgs",
69
+ "type": "github"
70
+ }
71
+ },
72
+ "root": {
73
+ "inputs": {
74
+ "kernel-builder": "kernel-builder"
75
+ }
76
+ },
77
+ "systems": {
78
+ "locked": {
79
+ "lastModified": 1681028828,
80
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
81
+ "owner": "nix-systems",
82
+ "repo": "default",
83
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
84
+ "type": "github"
85
+ },
86
+ "original": {
87
+ "owner": "nix-systems",
88
+ "repo": "default",
89
+ "type": "github"
90
+ }
91
+ }
92
+ },
93
+ "root": "root",
94
+ "version": 7
95
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
{
  # NOTE(review): the description previously read "Flake for triton-kernels
  # kernels" — a copy-paste from another project; this flake builds the
  # bitsandbytes MPS kernels.
  description = "Flake for bitsandbytes-mps kernels";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      # Embed the git revision; fall back to the dirty rev or the last
      # modification date for unversioned trees.
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
tests/__pycache__/test_bnb_mps.cpython-312-pytest-8.4.2.pyc ADDED
Binary file (18.1 kB). View file
 
tests/test_bnb_mps.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for bitsandbytes MPS 4-bit quantization kernels."""
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from bitsandbytes_mps import (
7
+ FP4,
8
+ NF4,
9
+ dequantize_4bit,
10
+ gemm_4bit,
11
+ gemv_4bit,
12
+ linear_4bit,
13
+ quantize_4bit,
14
+ )
15
+
16
+ # NF4 codebook values (matching bnb_types.h)
17
+ NF4_CODEBOOK = [
18
+ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
19
+ -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
20
+ 0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
21
+ 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
22
+ 0.7229568362236023, 1.0,
23
+ ]
24
+
25
+ FP4_CODEBOOK = [
26
+ 0.0, 0.005208333333, 0.66666667, 1.0, 0.33333333, 0.5, 0.16666667, 0.25,
27
+ 0.0, -0.005208333333, -0.66666667, -1.0, -0.33333333, -0.5, -0.16666667,
28
+ -0.25,
29
+ ]
30
+
31
+ DEVICE = "mps"
32
+
33
+
34
+ def _reference_quantize_nf4(x_flat, blocksize):
35
+ """Reference Python implementation of NF4 blockwise quantization."""
36
+ n = x_flat.numel()
37
+ num_blocks = (n + blocksize - 1) // blocksize
38
+ absmax = torch.zeros(num_blocks, dtype=torch.float32)
39
+ packed = torch.zeros((n + 1) // 2, dtype=torch.uint8)
40
+
41
+ codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)
42
+
43
+ for b in range(num_blocks):
44
+ start = b * blocksize
45
+ end = min(start + blocksize, n)
46
+ block = x_flat[start:end].float()
47
+ am = block.abs().max().item()
48
+ absmax[b] = am
49
+
50
+ if am > 0:
51
+ normalized = (block / am).clamp(-1, 1)
52
+ else:
53
+ normalized = torch.zeros_like(block)
54
+
55
+ for i in range(0, end - start, 2):
56
+ v0 = normalized[i].item()
57
+ q0 = (codebook - v0).abs().argmin().item()
58
+
59
+ q1 = 0
60
+ if i + 1 < end - start:
61
+ v1 = normalized[i + 1].item()
62
+ q1 = (codebook - v1).abs().argmin().item()
63
+
64
+ byte_idx = (start + i) // 2
65
+ packed[byte_idx] = (q0 << 4) | (q1 & 0x0F)
66
+
67
+ return packed, absmax
68
+
69
+
70
+ def _reference_dequantize_nf4(packed, absmax, blocksize, numel):
71
+ """Reference Python implementation of NF4 blockwise dequantization."""
72
+ codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)
73
+ output = torch.zeros(numel, dtype=torch.float32)
74
+
75
+ for i in range(numel):
76
+ byte_idx = i // 2
77
+ block_idx = i // blocksize
78
+ byte_val = packed[byte_idx].item()
79
+
80
+ if i % 2 == 0:
81
+ nibble = (byte_val >> 4) & 0x0F
82
+ else:
83
+ nibble = byte_val & 0x0F
84
+
85
+ output[i] = codebook[nibble] * absmax[block_idx].item()
86
+
87
+ return output
88
+
89
+
90
+ # ============================================================================
91
+ # Quantization / Dequantization Tests
92
+ # ============================================================================
93
+
94
+
95
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_quantize_dequantize_roundtrip(blocksize, quant_type, dtype):
    """Test that quantize -> dequantize approximately recovers the original."""
    torch.manual_seed(42)
    n = 1024
    x = torch.randn(n, dtype=dtype, device=DEVICE)

    packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=quant_type)

    # Layout contract: two 4-bit values per byte, one fp32 scale per block.
    assert packed.shape == (n // 2,)
    assert packed.dtype == torch.uint8
    assert absmax.dtype == torch.float32
    assert absmax.shape == ((n + blocksize - 1) // blocksize,)

    x_deq = dequantize_4bit(
        packed, absmax, blocksize=blocksize, quant_type=quant_type,
        numel=n, output_dtype=dtype,
    )

    assert x_deq.shape == (n,)
    assert x_deq.dtype == dtype

    # 4-bit quantization has significant error; check correlation
    x_cpu = x.float().cpu()
    x_deq_cpu = x_deq.float().cpu()
    cosine_sim = torch.nn.functional.cosine_similarity(
        x_cpu.unsqueeze(0), x_deq_cpu.unsqueeze(0)
    ).item()
    assert cosine_sim > 0.95, f"Cosine similarity too low: {cosine_sim}"


@pytest.mark.parametrize("blocksize", [64, 128])
def test_dequantize_matches_reference(blocksize):
    """Test dequantization matches the Python reference implementation."""
    torch.manual_seed(123)
    n = 256
    x = torch.randn(n, dtype=torch.float16, device=DEVICE)

    packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=NF4)

    # GPU dequantize
    x_deq = dequantize_4bit(
        packed, absmax, blocksize=blocksize, quant_type=NF4,
        numel=n, output_dtype=torch.float16,
    )

    # Reference dequantize (on CPU)
    x_ref = _reference_dequantize_nf4(
        packed.cpu(), absmax.cpu(), blocksize, n
    )

    torch.testing.assert_close(
        x_deq.float().cpu(), x_ref, rtol=1e-3, atol=1e-3
    )
151
+
152
+
153
+ # ============================================================================
154
+ # GEMV Tests
155
+ # ============================================================================
156
+
157
+
158
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
def test_gemv_correctness(blocksize, quant_type):
    """Test fused GEMV against dequantize + matmul reference."""
    torch.manual_seed(42)
    N, K = 256, 256

    # Create weight and quantize
    W = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
    W_flat = W.flatten()
    packed, absmax = quantize_4bit(W_flat, blocksize=blocksize, quant_type=quant_type)

    # Reshape for GEMV: one packed row / scale row per output feature.
    packed_w = packed.view(N, K // 2)
    absmax_w = absmax.view(N, -1)

    # Input vector
    x = torch.randn(K, dtype=torch.float16, device=DEVICE)

    # Fused GEMV
    y = gemv_4bit(x, packed_w, absmax_w, output_features=N,
                  blocksize=blocksize, quant_type=quant_type)

    # Reference: dequantize then matmul
    W_deq = dequantize_4bit(packed, absmax, blocksize=blocksize,
                            quant_type=quant_type, numel=N*K,
                            output_dtype=torch.float16)
    W_deq = W_deq.view(N, K)
    y_ref = W_deq @ x

    # Check relative error (mean-relative metric tolerant of 4-bit noise)
    rel_error = (y.float() - y_ref.float()).abs().mean() / y_ref.float().abs().mean()
    assert rel_error < 0.05, f"GEMV relative error too high: {rel_error}"


# ============================================================================
# GEMM Tests
# ============================================================================


@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
def test_gemm_correctness(blocksize, quant_type):
    """Test fused GEMM against dequantize + matmul reference."""
    torch.manual_seed(42)
    M, N, K = 8, 128, 128

    W = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
    W_flat = W.flatten()
    packed, absmax = quantize_4bit(W_flat, blocksize=blocksize, quant_type=quant_type)

    packed_w = packed.view(N, K // 2)
    absmax_w = absmax.view(N, -1)

    X = torch.randn(M, K, dtype=torch.float16, device=DEVICE)

    # Fused GEMM
    Y = gemm_4bit(X, packed_w, absmax_w, output_features=N,
                  blocksize=blocksize, quant_type=quant_type)

    # Reference: dequantize then matmul against the transposed weight.
    W_deq = dequantize_4bit(packed, absmax, blocksize=blocksize,
                            quant_type=quant_type, numel=N*K,
                            output_dtype=torch.float16)
    W_deq = W_deq.view(N, K)
    Y_ref = X @ W_deq.T

    rel_error = (Y.float() - Y_ref.float()).abs().mean() / Y_ref.float().abs().mean()
    assert rel_error < 0.05, f"GEMM relative error too high: {rel_error}"
227
+
228
+
229
+ # ============================================================================
230
+ # Linear layer test
231
+ # ============================================================================
232
+
233
+
234
def test_linear_4bit_auto_select():
    """Test that linear_4bit auto-selects GEMV vs GEMM."""
    torch.manual_seed(42)
    N, K = 128, 128

    W = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
    packed, absmax = quantize_4bit(W.flatten(), blocksize=64, quant_type=NF4)
    packed_w = packed.view(N, K // 2)
    absmax_w = absmax.view(N, -1)

    # Single vector - should use GEMV
    x = torch.randn(K, dtype=torch.float16, device=DEVICE)
    y = linear_4bit(x, packed_w, absmax_w, output_features=N)
    assert y.shape == (N,)

    # Batch - should use GEMM
    X = torch.randn(4, K, dtype=torch.float16, device=DEVICE)
    Y = linear_4bit(X, packed_w, absmax_w, output_features=N)
    assert Y.shape == (4, N)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
torch-ext/bitsandbytes_mps/__init__.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+ # Quant type constants (match bitsandbytes DataType_t)
8
+ FP4 = 1
9
+ NF4 = 2
10
+
11
+
12
def quantize_4bit(
    input: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Blockwise 4-bit quantization using NF4 or FP4 codebook.

    Thin wrapper around the compiled ``bnb_quantize_4bit`` op.

    Args:
        input: Input tensor on MPS device (float16, bfloat16, or float32).
        blocksize: Number of elements per quantization block (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        Tuple of (packed, absmax):
            packed: uint8 tensor of packed 4-bit values [numel/2].
            absmax: float32 tensor of per-block max absolute values.
    """
    return ops.bnb_quantize_4bit(input, blocksize, quant_type)
30
+
31
+
32
def dequantize_4bit(
    packed: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
    numel: int = -1,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Blockwise 4-bit dequantization using NF4 or FP4 codebook.

    Args:
        packed: uint8 tensor of packed 4-bit values.
        absmax: float32 tensor of per-block max absolute values.
        blocksize: Number of elements per quantization block (64 or 128).
        quant_type: FP4 (1) or NF4 (2).
        numel: Number of elements in the original tensor; any negative
            value means "infer as two elements per packed byte".
        output_dtype: Output scalar type.

    Returns:
        Dequantized tensor.
    """
    effective_numel = packed.numel() * 2 if numel < 0 else numel
    return ops.bnb_dequantize_4bit(
        packed, absmax, blocksize, quant_type, effective_numel, output_dtype
    )
59
+
60
+
61
def gemv_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Fused matrix-vector multiply with 4-bit quantized weights.

    Computes y = dequant(W) @ x, where W is blockwise NF4/FP4 quantized.
    Thin wrapper around the compiled ``bnb_gemv_4bit`` op.

    Args:
        x: Input vector [..., K] on MPS device.
        w: Packed weight matrix [N, K/2] (uint8) on MPS device.
        absmax: Per-block scales [N, ceil(K/blocksize)] (float32).
        output_features: Number of output features (N).
        blocksize: Quantization block size (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        Output tensor [..., N].
    """
    return ops.bnb_gemv_4bit(x, w, absmax, blocksize, quant_type, output_features)
85
+
86
+
87
def gemm_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Fused matrix-matrix multiply with 4-bit quantized transposed weights.

    Computes Y = X @ dequant(W).T, where W is blockwise NF4/FP4 quantized.
    Thin wrapper around the compiled ``bnb_gemm_4bit`` op.

    Args:
        x: Input matrix [..., M, K] on MPS device.
        w: Packed weight matrix [N, K/2] (uint8) on MPS device.
        absmax: Per-block scales [N, ceil(K/blocksize)] (float32).
        output_features: Number of output features (N).
        blocksize: Quantization block size (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        Output tensor [..., M, N].
    """
    return ops.bnb_gemm_4bit(x, w, absmax, blocksize, quant_type, output_features)
111
+
112
+
113
def linear_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """4-bit quantized linear layer (auto-selects GEMV or GEMM).

    Args:
        x: Input tensor on MPS device.
        w: Packed weight [N, K/2] (uint8).
        absmax: Scales [N, ceil(K/blocksize)] (float32).
        output_features: N.
        blocksize: 64 or 128.
        quant_type: FP4 (1) or NF4 (2).
        bias: Optional bias [N].

    Returns:
        Output tensor.
    """
    # Vector-like inputs (1-D, or any shape whose row dim is 1) take the
    # fused GEMV path; everything else goes through GEMM.
    input_1d = x.dim() == 1
    if input_1d or (x.dim() >= 2 and x.size(-2) == 1):
        x_flat = x.view(x.size(-1)) if input_1d else x.squeeze(-2)
        y = gemv_4bit(
            x_flat,
            w,
            absmax,
            output_features,
            blocksize,
            quant_type,
        )
        if input_1d:
            # NOTE(review): assumes gemv_4bit returns a leading singleton
            # dim for 1-D inputs; if it actually returns shape (N,), this
            # squeeze is a no-op except when N == 1 — confirm vs. kernel.
            y = y.squeeze(0)
        elif x.dim() >= 2:
            # Restore the singleton row dim that was squeezed off above.
            y = y.unsqueeze(-2)
    else:
        y = gemm_4bit(x, w, absmax, output_features, blocksize, quant_type)

    if bias is not None:
        y = y + bias

    return y
158
+
159
+ __all__ = [
160
+ "quantize_4bit",
161
+ "dequantize_4bit",
162
+ "gemv_4bit",
163
+ "gemm_4bit",
164
+ "linear_4bit",
165
+ ]
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Schema registration: declares the op signatures in the extension's
// namespace so torch can dispatch calls regardless of backend.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // 4-bit quantization
  ops.def(
      "bnb_quantize_4bit(Tensor input, int blocksize, int quant_type) "
      "-> (Tensor, Tensor)");

  // 4-bit dequantization
  ops.def(
      "bnb_dequantize_4bit(Tensor packed, Tensor absmax, int blocksize, "
      "int quant_type, int numel, ScalarType output_dtype) -> Tensor");

  // Fused GEMV with 4-bit weights
  ops.def(
      "bnb_gemv_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, "
      "int quant_type, int output_features) -> Tensor");

  // Fused GEMM with 4-bit transposed weights
  ops.def(
      "bnb_gemm_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, "
      "int quant_type, int output_features) -> Tensor");
}

// MPS backend implementations (functions declared in torch_binding.h).
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, MPS, ops) {
  ops.impl("bnb_quantize_4bit", bnb_quantize_4bit);
  ops.impl("bnb_dequantize_4bit", bnb_dequantize_4bit);
  ops.impl("bnb_gemv_4bit", bnb_gemv_4bit);
  ops.impl("bnb_gemm_4bit", bnb_gemm_4bit);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#pragma once

#include <ATen/ATen.h>
#include <tuple>

// ============================================================================
// Blockwise 4-bit quantization (NF4/FP4)
// ============================================================================

// Quantize and return both packed tensor and absmax.
// quant_type selects the codebook: FP4 (1) or NF4 (2).
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type);

// ============================================================================
// Blockwise 4-bit dequantization
// ============================================================================

// Dequantize packed 4-bit tensor back to output_dtype.
// numel is the element count of the original (unpacked) tensor.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    c10::ScalarType output_dtype);

// ============================================================================
// Fused GEMV: y = dequant(W) @ x
// W: [N, K/2] packed, absmax: [N, K_groups], x: [..., K], y: [..., N]
// ============================================================================

at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);

// ============================================================================
// Fused GEMM: Y = X @ dequant(W).T
// X: [M, K], W: [N, K/2] packed, absmax: [N, K_groups], Y: [M, N]
// ============================================================================

at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);