Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- .pytest_cache/.gitignore +2 -0
- .pytest_cache/CACHEDIR.TAG +4 -0
- .pytest_cache/README.md +8 -0
- .pytest_cache/v/cache/lastfailed +1 -0
- .pytest_cache/v/cache/nodeids +21 -0
- README.md +70 -0
- bitsandbytes_mps/bf16.h +29 -0
- bitsandbytes_mps/bf16_math.h +380 -0
- bitsandbytes_mps/bnb_quantized.h +541 -0
- bitsandbytes_mps/bnb_quantized.metal +48 -0
- bitsandbytes_mps/bnb_quantized.mm +382 -0
- bitsandbytes_mps/bnb_types.h +180 -0
- bitsandbytes_mps/complex.h +173 -0
- bitsandbytes_mps/defines.h +24 -0
- bitsandbytes_mps/gemm/defines.h +5 -0
- bitsandbytes_mps/gemm/gemm.h +295 -0
- bitsandbytes_mps/gemm/loader.h +137 -0
- bitsandbytes_mps/gemm/mma.h +735 -0
- bitsandbytes_mps/gemm/params.h +64 -0
- bitsandbytes_mps/gemm/transforms.h +72 -0
- bitsandbytes_mps/gemm/utils.h +42 -0
- bitsandbytes_mps/gemm/utils/integral_constant.h +134 -0
- bitsandbytes_mps/gemm/utils/type_traits.h +55 -0
- bitsandbytes_mps/quantized_utils.h +90 -0
- bitsandbytes_mps/utils.h +393 -0
- build.toml +49 -0
- build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so +3 -0
- build/torch210-metal-aarch64-darwin/_ops.py +3 -3
- build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so +3 -0
- build/torch29-metal-aarch64-darwin/_ops.py +3 -3
- flake.lock +95 -0
- flake.nix +17 -0
- tests/__pycache__/test_bnb_mps.cpython-312-pytest-8.4.2.pyc +0 -0
- tests/test_bnb_mps.py +256 -0
- torch-ext/bitsandbytes_mps/__init__.py +165 -0
- torch-ext/torch_binding.cpp +35 -0
- torch-ext/torch_binding.h +53 -0
.gitattributes
CHANGED
|
@@ -39,3 +39,5 @@ build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filt
|
|
| 39 |
build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 39 |
build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_1c65113_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 40 |
torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 41 |
torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so filter=lfs diff=lfs merge=lfs -text
|
.pytest_cache/.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Created by pytest automatically.
|
| 2 |
+
*
|
.pytest_cache/CACHEDIR.TAG
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
| 2 |
+
# This file is a cache directory tag created by pytest.
|
| 3 |
+
# For information about cache directory tags, see:
|
| 4 |
+
# https://bford.info/cachedir/spec.html
|
.pytest_cache/README.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytest cache directory #
|
| 2 |
+
|
| 3 |
+
This directory contains data from the pytest's cache plugin,
|
| 4 |
+
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
| 5 |
+
|
| 6 |
+
**Do not** commit this to version control.
|
| 7 |
+
|
| 8 |
+
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
.pytest_cache/v/cache/lastfailed
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{}
|
.pytest_cache/v/cache/nodeids
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
"tests/test_bnb_mps.py::test_dequantize_matches_reference[128]",
|
| 3 |
+
"tests/test_bnb_mps.py::test_dequantize_matches_reference[64]",
|
| 4 |
+
"tests/test_bnb_mps.py::test_gemm_correctness[1-128]",
|
| 5 |
+
"tests/test_bnb_mps.py::test_gemm_correctness[1-64]",
|
| 6 |
+
"tests/test_bnb_mps.py::test_gemm_correctness[2-128]",
|
| 7 |
+
"tests/test_bnb_mps.py::test_gemm_correctness[2-64]",
|
| 8 |
+
"tests/test_bnb_mps.py::test_gemv_correctness[1-128]",
|
| 9 |
+
"tests/test_bnb_mps.py::test_gemv_correctness[1-64]",
|
| 10 |
+
"tests/test_bnb_mps.py::test_gemv_correctness[2-128]",
|
| 11 |
+
"tests/test_bnb_mps.py::test_gemv_correctness[2-64]",
|
| 12 |
+
"tests/test_bnb_mps.py::test_linear_4bit_auto_select",
|
| 13 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-1-128]",
|
| 14 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-1-64]",
|
| 15 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-2-128]",
|
| 16 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype0-2-64]",
|
| 17 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-1-128]",
|
| 18 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-1-64]",
|
| 19 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-2-128]",
|
| 20 |
+
"tests/test_bnb_mps.py::test_quantize_dequantize_roundtrip[dtype1-2-64]"
|
| 21 |
+
]
|
README.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# bitsandbytes-mps
|
| 2 |
+
|
| 3 |
+
Metal (MPS) kernels for bitsandbytes 4-bit quantization on Apple Silicon.
|
| 4 |
+
|
| 5 |
+
Provides NF4 and FP4 blockwise quantization, dequantization, and **fused GEMV/GEMM** operations for efficient inference with 4-bit quantized models on macOS.
|
| 6 |
+
|
| 7 |
+
## Operations
|
| 8 |
+
|
| 9 |
+
| Operation | Description |
|
| 10 |
+
|-----------|-------------|
|
| 11 |
+
| `quantize_4bit` | Blockwise 4-bit quantization (NF4/FP4) with per-block absmax |
|
| 12 |
+
| `dequantize_4bit` | Blockwise 4-bit dequantization using codebook lookup |
|
| 13 |
+
| `gemv_4bit` | Fused dequantize + matrix-vector multiply (batch_size=1 inference) |
|
| 14 |
+
| `gemm_4bit` | Fused dequantize + matrix-matrix multiply (larger batch inference) |
|
| 15 |
+
| `linear_4bit` | Auto-selecting linear layer (GEMV for vectors, GEMM for matrices) |
|
| 16 |
+
|
| 17 |
+
## Quantization Format
|
| 18 |
+
|
| 19 |
+
Uses the bitsandbytes blockwise quantization scheme:
|
| 20 |
+
- **Packing**: 2 values per byte (high nibble = first element, low nibble = second)
|
| 21 |
+
- **Scaling**: One `absmax` (float32) per block of `blocksize` elements
|
| 22 |
+
- **Codebook**: NF4 (16 values optimized for normal distributions) or FP4 (sign-magnitude floating point)
|
| 23 |
+
- **Dequantization**: `value = codebook[4bit_index] * absmax`
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
```python
|
| 28 |
+
import torch
|
| 29 |
+
from bitsandbytes_mps import quantize_4bit, dequantize_4bit, gemv_4bit, gemm_4bit, NF4
|
| 30 |
+
|
| 31 |
+
# Quantize a weight matrix
|
| 32 |
+
weight = torch.randn(4096, 4096, dtype=torch.float16, device="mps")
|
| 33 |
+
packed, absmax = quantize_4bit(weight.flatten(), blocksize=64, quant_type=NF4)
|
| 34 |
+
|
| 35 |
+
# Dequantize
|
| 36 |
+
weight_deq = dequantize_4bit(packed, absmax, blocksize=64, quant_type=NF4,
|
| 37 |
+
numel=weight.numel(), output_dtype=torch.float16)
|
| 38 |
+
|
| 39 |
+
# Fused GEMV (single vector)
|
| 40 |
+
x = torch.randn(4096, dtype=torch.float16, device="mps")
|
| 41 |
+
packed_w = packed.view(4096, -1) # [N, K/2]
|
| 42 |
+
absmax_w = absmax.view(4096, -1) # [N, K_groups]
|
| 43 |
+
y = gemv_4bit(x, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)
|
| 44 |
+
|
| 45 |
+
# Fused GEMM (batch of vectors)
|
| 46 |
+
X = torch.randn(8, 4096, dtype=torch.float16, device="mps")
|
| 47 |
+
Y = gemm_4bit(X, packed_w, absmax_w, output_features=4096, blocksize=64, quant_type=NF4)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Supported Configurations
|
| 51 |
+
|
| 52 |
+
- **Scalar types**: float16, bfloat16, float32
|
| 53 |
+
- **Block sizes**: 64, 128
|
| 54 |
+
- **Quant types**: FP4, NF4
|
| 55 |
+
|
| 56 |
+
## Architecture
|
| 57 |
+
|
| 58 |
+
The kernels are adapted from [MLX quantization Metal kernels](https://github.com/ml-explore/mlx) with the following modifications:
|
| 59 |
+
|
| 60 |
+
1. **Codebook-based dequantization** replaces MLX's affine `scale * q + bias` with `codebook[q] * absmax`
|
| 61 |
+
2. **BnB packing format**: high nibble first (vs MLX's low nibble first)
|
| 62 |
+
3. **`BnBQuantizedBlockLoader`**: Custom block loader for tiled GEMM that dequantizes on-the-fly using codebook lookup
|
| 63 |
+
4. **Binary search quantization**: Efficient NF4/FP4 quantization using decision trees (matching CUDA kernels)
|
| 64 |
+
|
| 65 |
+
## Building
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
pip install kernel-builder
|
| 69 |
+
kernel-builder build .
|
| 70 |
+
```
|
bitsandbytes_mps/bf16.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_stdlib>
|
| 6 |
+
|
| 7 |
+
using namespace metal;
|
| 8 |
+
|
| 9 |
+
#if __METAL_VERSION__ >= 310
|
| 10 |
+
typedef bfloat bfloat16_t;
|
| 11 |
+
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
| 12 |
+
return as_type<uint16_t>(x);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
| 16 |
+
return as_type<bfloat16_t>(x);
|
| 17 |
+
}
|
| 18 |
+
#else
|
| 19 |
+
// bfloat not available before Metal 3.1; use a stub so the file parses
|
| 20 |
+
// but only half/float kernels will be instantiated.
|
| 21 |
+
typedef half bfloat16_t;
|
| 22 |
+
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
|
| 23 |
+
return as_type<uint16_t>(x);
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
|
| 27 |
+
return as_type<bfloat16_t>(x);
|
| 28 |
+
}
|
| 29 |
+
#endif
|
bitsandbytes_mps/bf16_math.h
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 6 |
+
// Metal math for bfloat16
|
| 7 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
|
| 9 |
+
/*
|
| 10 |
+
|
| 11 |
+
Following the Metal Shading Language Specification (Metal 3.1)
|
| 12 |
+
|
| 13 |
+
"bfloat is an extended itypeing point type that only allows implicit conversion
|
| 14 |
+
to a type of greater itypeing point rank. While bfloat can be implicitly
|
| 15 |
+
converted to itype, it cannot be implicitly converted to half, and neither
|
| 16 |
+
itype nor half can be implicitly converted to bfloat."
|
| 17 |
+
|
| 18 |
+
Further, as far as I can tell, the stdlib math/simd functions are not defined
|
| 19 |
+
for bfloat and calling with an argument of type bfloat will result in that
|
| 20 |
+
argument getting implicitly converted to itype which then returns an output
|
| 21 |
+
that is (likely) a itype which cannot be implicitly converted into a bfloat
|
| 22 |
+
|
| 23 |
+
This leads to situations where
|
| 24 |
+
bfloat a = 5.0bf;
|
| 25 |
+
bfloat b = metal::abs(a); // this will throw an error since abs return itype
|
| 26 |
+
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
|
| 27 |
+
|
| 28 |
+
For the moment, I will be adding overloaded instantiations of the math
|
| 29 |
+
functions to accordingly automatically handle the casting
|
| 30 |
+
|
| 31 |
+
*/
|
| 32 |
+
|
| 33 |
+
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
|
| 34 |
+
\
|
| 35 |
+
METAL_FUNC otype abs(itype x) { \
|
| 36 |
+
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
| 37 |
+
} \
|
| 38 |
+
METAL_FUNC otype acos(itype x) { \
|
| 39 |
+
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
|
| 40 |
+
} \
|
| 41 |
+
METAL_FUNC otype acosh(itype x) { \
|
| 42 |
+
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
|
| 43 |
+
} \
|
| 44 |
+
METAL_FUNC otype asin(itype x) { \
|
| 45 |
+
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
|
| 46 |
+
} \
|
| 47 |
+
METAL_FUNC otype asinh(itype x) { \
|
| 48 |
+
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
|
| 49 |
+
} \
|
| 50 |
+
METAL_FUNC otype atan(itype y_over_x) { \
|
| 51 |
+
return static_cast<otype>( \
|
| 52 |
+
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
|
| 53 |
+
} \
|
| 54 |
+
METAL_FUNC otype atan2(itype y, itype x) { \
|
| 55 |
+
return static_cast<otype>( \
|
| 56 |
+
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
|
| 57 |
+
} \
|
| 58 |
+
METAL_FUNC otype atanh(itype x) { \
|
| 59 |
+
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
|
| 60 |
+
} \
|
| 61 |
+
METAL_FUNC otype ceil(itype x) { \
|
| 62 |
+
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
|
| 63 |
+
} \
|
| 64 |
+
METAL_FUNC otype cos(itype x) { \
|
| 65 |
+
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
|
| 66 |
+
} \
|
| 67 |
+
METAL_FUNC otype cosh(itype x) { \
|
| 68 |
+
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
|
| 69 |
+
} \
|
| 70 |
+
METAL_FUNC otype cospi(itype x) { \
|
| 71 |
+
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
|
| 72 |
+
} \
|
| 73 |
+
METAL_FUNC otype divide(itype x, itype y) { \
|
| 74 |
+
return static_cast<otype>( \
|
| 75 |
+
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 76 |
+
} \
|
| 77 |
+
METAL_FUNC otype exp(itype x) { \
|
| 78 |
+
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
|
| 79 |
+
} \
|
| 80 |
+
METAL_FUNC otype exp10(itype x) { \
|
| 81 |
+
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
|
| 82 |
+
} \
|
| 83 |
+
METAL_FUNC otype exp2(itype x) { \
|
| 84 |
+
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
|
| 85 |
+
} \
|
| 86 |
+
METAL_FUNC otype fabs(itype x) { \
|
| 87 |
+
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
|
| 88 |
+
} \
|
| 89 |
+
METAL_FUNC otype fdim(itype x, itype y) { \
|
| 90 |
+
ctype t = static_cast<ctype>(x - y); \
|
| 91 |
+
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
|
| 92 |
+
} \
|
| 93 |
+
METAL_FUNC otype floor(itype x) { \
|
| 94 |
+
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
|
| 95 |
+
} \
|
| 96 |
+
METAL_FUNC otype fma(itype x, itype y, itype z) { \
|
| 97 |
+
return static_cast<otype>(__metal_fma( \
|
| 98 |
+
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
|
| 99 |
+
} \
|
| 100 |
+
METAL_FUNC otype fmax(itype x, itype y) { \
|
| 101 |
+
return static_cast<otype>( \
|
| 102 |
+
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 103 |
+
} \
|
| 104 |
+
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
|
| 105 |
+
return static_cast<otype>(__metal_fmax3( \
|
| 106 |
+
static_cast<ctype>(x), \
|
| 107 |
+
static_cast<ctype>(y), \
|
| 108 |
+
static_cast<ctype>(z), \
|
| 109 |
+
mfast)); \
|
| 110 |
+
} \
|
| 111 |
+
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
|
| 112 |
+
return static_cast<otype>(__metal_fmedian3( \
|
| 113 |
+
static_cast<ctype>(x), \
|
| 114 |
+
static_cast<ctype>(y), \
|
| 115 |
+
static_cast<ctype>(z), \
|
| 116 |
+
mfast)); \
|
| 117 |
+
} \
|
| 118 |
+
METAL_FUNC otype fmin(itype x, itype y) { \
|
| 119 |
+
return static_cast<otype>( \
|
| 120 |
+
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 121 |
+
} \
|
| 122 |
+
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
|
| 123 |
+
return static_cast<otype>(__metal_fmin3( \
|
| 124 |
+
static_cast<ctype>(x), \
|
| 125 |
+
static_cast<ctype>(y), \
|
| 126 |
+
static_cast<ctype>(z), \
|
| 127 |
+
mfast)); \
|
| 128 |
+
} \
|
| 129 |
+
METAL_FUNC otype fmod(itype x, itype y) { \
|
| 130 |
+
return static_cast<otype>( \
|
| 131 |
+
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 132 |
+
} \
|
| 133 |
+
METAL_FUNC otype fract(itype x) { \
|
| 134 |
+
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
|
| 135 |
+
} \
|
| 136 |
+
METAL_FUNC otype frexp(itype x, thread int& exp) { \
|
| 137 |
+
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
|
| 138 |
+
} \
|
| 139 |
+
METAL_FUNC otype ldexp(itype x, int k) { \
|
| 140 |
+
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
|
| 141 |
+
} \
|
| 142 |
+
METAL_FUNC otype log(itype x) { \
|
| 143 |
+
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
|
| 144 |
+
} \
|
| 145 |
+
METAL_FUNC otype log10(itype x) { \
|
| 146 |
+
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
|
| 147 |
+
} \
|
| 148 |
+
METAL_FUNC otype log2(itype x) { \
|
| 149 |
+
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
|
| 150 |
+
} \
|
| 151 |
+
METAL_FUNC otype max(itype x, itype y) { \
|
| 152 |
+
return static_cast<otype>( \
|
| 153 |
+
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 154 |
+
} \
|
| 155 |
+
METAL_FUNC otype max3(itype x, itype y, itype z) { \
|
| 156 |
+
return static_cast<otype>(__metal_fmax3( \
|
| 157 |
+
static_cast<ctype>(x), \
|
| 158 |
+
static_cast<ctype>(y), \
|
| 159 |
+
static_cast<ctype>(z), \
|
| 160 |
+
mfast)); \
|
| 161 |
+
} \
|
| 162 |
+
METAL_FUNC otype median3(itype x, itype y, itype z) { \
|
| 163 |
+
return static_cast<otype>(__metal_fmedian3( \
|
| 164 |
+
static_cast<ctype>(x), \
|
| 165 |
+
static_cast<ctype>(y), \
|
| 166 |
+
static_cast<ctype>(z), \
|
| 167 |
+
mfast)); \
|
| 168 |
+
} \
|
| 169 |
+
METAL_FUNC otype min(itype x, itype y) { \
|
| 170 |
+
return static_cast<otype>( \
|
| 171 |
+
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 172 |
+
} \
|
| 173 |
+
METAL_FUNC otype min3(itype x, itype y, itype z) { \
|
| 174 |
+
return static_cast<otype>(__metal_fmin3( \
|
| 175 |
+
static_cast<ctype>(x), \
|
| 176 |
+
static_cast<ctype>(y), \
|
| 177 |
+
static_cast<ctype>(z), \
|
| 178 |
+
mfast)); \
|
| 179 |
+
} \
|
| 180 |
+
METAL_FUNC otype nextafter(itype x, itype y) { \
|
| 181 |
+
return static_cast<otype>( \
|
| 182 |
+
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
|
| 183 |
+
} \
|
| 184 |
+
METAL_FUNC otype pow(itype x, itype y) { \
|
| 185 |
+
return static_cast<otype>( \
|
| 186 |
+
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 187 |
+
} \
|
| 188 |
+
METAL_FUNC otype powr(itype x, itype y) { \
|
| 189 |
+
return static_cast<otype>( \
|
| 190 |
+
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
|
| 191 |
+
} \
|
| 192 |
+
METAL_FUNC otype rint(itype x) { \
|
| 193 |
+
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
|
| 194 |
+
} \
|
| 195 |
+
METAL_FUNC otype round(itype x) { \
|
| 196 |
+
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
|
| 197 |
+
} \
|
| 198 |
+
METAL_FUNC otype rsqrt(itype x) { \
|
| 199 |
+
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
|
| 200 |
+
} \
|
| 201 |
+
METAL_FUNC otype sin(itype x) { \
|
| 202 |
+
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
|
| 203 |
+
} \
|
| 204 |
+
METAL_FUNC otype sinh(itype x) { \
|
| 205 |
+
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
|
| 206 |
+
} \
|
| 207 |
+
METAL_FUNC otype sinpi(itype x) { \
|
| 208 |
+
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
|
| 209 |
+
} \
|
| 210 |
+
METAL_FUNC otype sqrt(itype x) { \
|
| 211 |
+
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
|
| 212 |
+
} \
|
| 213 |
+
METAL_FUNC otype tan(itype x) { \
|
| 214 |
+
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
|
| 215 |
+
} \
|
| 216 |
+
METAL_FUNC otype tanh(itype x) { \
|
| 217 |
+
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
|
| 218 |
+
} \
|
| 219 |
+
METAL_FUNC otype tanpi(itype x) { \
|
| 220 |
+
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
|
| 221 |
+
} \
|
| 222 |
+
METAL_FUNC otype trunc(itype x) { \
|
| 223 |
+
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
namespace metal {
|
| 227 |
+
|
| 228 |
+
instantiate_metal_math_funcs(
|
| 229 |
+
bfloat16_t,
|
| 230 |
+
bfloat16_t,
|
| 231 |
+
float,
|
| 232 |
+
__METAL_MAYBE_FAST_MATH__);
|
| 233 |
+
|
| 234 |
+
namespace fast {
|
| 235 |
+
|
| 236 |
+
instantiate_metal_math_funcs(
|
| 237 |
+
bfloat16_t,
|
| 238 |
+
bfloat16_t,
|
| 239 |
+
float,
|
| 240 |
+
__METAL_FAST_MATH__);
|
| 241 |
+
|
| 242 |
+
} // namespace fast
|
| 243 |
+
|
| 244 |
+
namespace precise {
|
| 245 |
+
|
| 246 |
+
instantiate_metal_math_funcs(
|
| 247 |
+
bfloat16_t,
|
| 248 |
+
bfloat16_t,
|
| 249 |
+
float,
|
| 250 |
+
__METAL_PRECISE_MATH__);
|
| 251 |
+
|
| 252 |
+
} // namespace precise
|
| 253 |
+
|
| 254 |
+
} // namespace metal
|
| 255 |
+
|
| 256 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 257 |
+
// Metal simd for bfloat16
|
| 258 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 259 |
+
|
| 260 |
+
#define instantiate_metal_simd_comm_funcs( \
|
| 261 |
+
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
|
| 262 |
+
\
|
| 263 |
+
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
|
| 264 |
+
return ctype_to_otype( \
|
| 265 |
+
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
|
| 266 |
+
} \
|
| 267 |
+
\
|
| 268 |
+
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
|
| 269 |
+
return ctype_to_otype( \
|
| 270 |
+
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
|
| 271 |
+
} \
|
| 272 |
+
\
|
| 273 |
+
METAL_FUNC otype simd_shuffle_and_fill_down( \
|
| 274 |
+
itype data, itype filling_data, ushort delta, ushort modulo) { \
|
| 275 |
+
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
| 276 |
+
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
| 277 |
+
} \
|
| 278 |
+
\
|
| 279 |
+
METAL_FUNC otype simd_shuffle_and_fill_down( \
|
| 280 |
+
itype data, itype filling_data, ushort delta) { \
|
| 281 |
+
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
|
| 282 |
+
itype_to_ctype(data), \
|
| 283 |
+
itype_to_ctype(filling_data), \
|
| 284 |
+
delta, \
|
| 285 |
+
__metal_get_simdgroup_size(ushort()))); \
|
| 286 |
+
} \
|
| 287 |
+
\
|
| 288 |
+
METAL_FUNC otype simd_shuffle_and_fill_up( \
|
| 289 |
+
itype data, itype filling_data, ushort delta, ushort modulo) { \
|
| 290 |
+
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
| 291 |
+
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
|
| 292 |
+
} \
|
| 293 |
+
\
|
| 294 |
+
METAL_FUNC otype simd_shuffle_and_fill_up( \
|
| 295 |
+
itype data, itype filling_data, ushort delta) { \
|
| 296 |
+
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
|
| 297 |
+
itype_to_ctype(data), \
|
| 298 |
+
itype_to_ctype(filling_data), \
|
| 299 |
+
delta, \
|
| 300 |
+
__metal_get_simdgroup_size(ushort()))); \
|
| 301 |
+
} \
|
| 302 |
+
\
|
| 303 |
+
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
|
| 304 |
+
return ctype_to_otype( \
|
| 305 |
+
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
|
| 306 |
+
} \
|
| 307 |
+
\
|
| 308 |
+
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
|
| 309 |
+
return ctype_to_otype( \
|
| 310 |
+
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
|
| 311 |
+
} \
|
| 312 |
+
\
|
| 313 |
+
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
|
| 314 |
+
return ctype_to_otype( \
|
| 315 |
+
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
|
| 316 |
+
} \
|
| 317 |
+
\
|
| 318 |
+
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
|
| 319 |
+
return ctype_to_otype( \
|
| 320 |
+
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
|
| 321 |
+
} \
|
| 322 |
+
\
|
| 323 |
+
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
|
| 324 |
+
return ctype_to_otype( \
|
| 325 |
+
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
|
| 329 |
+
\
|
| 330 |
+
METAL_FUNC otype simd_max(itype data) { \
|
| 331 |
+
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
|
| 332 |
+
} \
|
| 333 |
+
\
|
| 334 |
+
METAL_FUNC otype simd_min(itype data) { \
|
| 335 |
+
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
|
| 336 |
+
} \
|
| 337 |
+
\
|
| 338 |
+
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
|
| 339 |
+
return static_cast<otype>( \
|
| 340 |
+
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
|
| 341 |
+
} \
|
| 342 |
+
\
|
| 343 |
+
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
|
| 344 |
+
return static_cast<otype>( \
|
| 345 |
+
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
|
| 346 |
+
} \
|
| 347 |
+
\
|
| 348 |
+
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
|
| 349 |
+
return static_cast<otype>( \
|
| 350 |
+
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
|
| 351 |
+
} \
|
| 352 |
+
\
|
| 353 |
+
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
|
| 354 |
+
return static_cast<otype>( \
|
| 355 |
+
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
|
| 356 |
+
} \
|
| 357 |
+
\
|
| 358 |
+
METAL_FUNC otype simd_product(itype data) { \
|
| 359 |
+
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
|
| 360 |
+
} \
|
| 361 |
+
\
|
| 362 |
+
METAL_FUNC otype simd_sum(itype data) { \
|
| 363 |
+
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
|
| 364 |
+
} \
|
| 365 |
+
\
|
| 366 |
+
METAL_FUNC otype simd_xor(itype data) { \
|
| 367 |
+
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
namespace metal {
|
| 371 |
+
|
| 372 |
+
instantiate_metal_simd_comm_funcs(
|
| 373 |
+
bfloat16_t,
|
| 374 |
+
bfloat16_t,
|
| 375 |
+
uint16_t,
|
| 376 |
+
bfloat16_to_uint16,
|
| 377 |
+
uint16_to_bfloat16);
|
| 378 |
+
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
|
| 379 |
+
|
| 380 |
+
} // namespace metal
|
bitsandbytes_mps/bnb_quantized.h
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// bitsandbytes MPS Metal kernels - 4-bit quantized operations
|
| 2 |
+
// Adapted from MLX quantized.h for bitsandbytes NF4/FP4 format.
|
| 3 |
+
//
|
| 4 |
+
// Key differences from MLX affine quantization:
|
| 5 |
+
// MLX: dequant(q) = scale * q_int + bias (linear mapping)
|
| 6 |
+
// BnB: dequant(q) = codebook[q_int] * absmax (lookup-based)
|
| 7 |
+
//
|
| 8 |
+
// Packing format:
|
| 9 |
+
// BnB: high nibble = first element, low nibble = second element
|
| 10 |
+
// Two 4-bit values per byte, pack_factor = 2
|
| 11 |
+
|
| 12 |
+
#include <metal_simdgroup>
|
| 13 |
+
#include <metal_stdlib>
|
| 14 |
+
|
| 15 |
+
#include "bnb_types.h"
|
| 16 |
+
|
| 17 |
+
using namespace metal;
|
| 18 |
+
|
| 19 |
+
#define MLX_MTL_CONST static constant constexpr const
|
| 20 |
+
|
| 21 |
+
MLX_MTL_CONST int SIMD_SIZE = 32;
|
| 22 |
+
|
| 23 |
+
// ============================================================================
|
| 24 |
+
// BnBQuantizedBlockLoader
|
| 25 |
+
//
|
| 26 |
+
// Loads blocks of BnB 4-bit packed weights into threadgroup memory,
|
| 27 |
+
// performing codebook dequantization on the fly.
|
| 28 |
+
// Adapted from MLX QuantizedBlockLoader.
|
| 29 |
+
//
|
| 30 |
+
// Template parameters:
|
| 31 |
+
// T - output scalar type (float16_t, bfloat16_t, float)
|
| 32 |
+
// BROWS - number of rows in the tile
|
| 33 |
+
// BCOLS - number of columns in the tile (unpacked)
|
| 34 |
+
// dst_ld - leading dimension of destination (threadgroup memory)
|
| 35 |
+
// reduction_dim - 0 for K along rows, 1 for K along columns
|
| 36 |
+
// tgp_size - threads per threadgroup
|
| 37 |
+
// blocksize - BnB blocksize (elements per absmax value)
|
| 38 |
+
// quant_type - BNB_FP4 (1) or BNB_NF4 (2)
|
| 39 |
+
// ============================================================================
|
| 40 |
+
|
| 41 |
+
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short blocksize,
    int quant_type>
struct BnBQuantizedBlockLoader {
  static_assert(
      BCOLS <= blocksize,
      "The blocksize should be larger than the tile columns");
  static_assert(
      blocksize % BCOLS == 0,
      "The blocksize should be divisible by the tile columns");

  // BnB packs two 4-bit values per byte (high nibble = first element).
  MLX_MTL_CONST short pack_factor = 2;
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  // Packed bytes each thread copies per load.
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1
                                        : (BCOLS_PACKED * BROWS) / tgp_size;
  // Column-tiles that share one absmax value.
  MLX_MTL_CONST short group_steps = blocksize / BCOLS;

  const int src_ld;        // leading dimension of src, in (unpacked) elements
  const int tile_stride;   // packed-byte advance per next()
  short group_step_cnt;    // tile counter within the current absmax group
  const int group_stride;  // absmax advance per next() when K runs along rows

  const short thread_idx;
  const short bi;  // tile row this thread loads
  const short bj;  // packed tile column this thread loads

  threadgroup T* dst;
  const device uint8_t* src;
  const device float* absmax_ptr;

  /// Positions this thread's src/dst/absmax pointers at tile (0, 0).
  BnBQuantizedBlockLoader(
      const device uint8_t* src_,
      const device float* absmax_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED : BROWS * src_ld / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / blocksize),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld / pack_factor + bj),
        absmax_ptr(absmax_ + bi * src_ld / blocksize) {}

  /// Dequantizes a full tile into threadgroup memory; no bounds checks.
  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    float am = *absmax_ptr;
    for (int i = 0; i < n_reads; i++) {
      bnb_dequantize<T, pack_factor, quant_type>(
          src + i, T(am), dst + i * pack_factor);
    }
  }

  /// Dequantizes a partial tile, zero-filling rows past the valid region.
  ///
  /// `src_tile_dim` is packed as (columns, rows): callers pass
  /// short2(BK, valid_rows), so the row index `bi` must be compared against
  /// the .y component. (The original compared reduction_dim == 1 against
  /// .x — the column count, which always equals BK >= BROWS here — so rows
  /// past the valid region were never zeroed and `src` was read out of
  /// bounds when the output dimension was not a multiple of the tile size.)
  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    float am = *absmax_ptr;
    for (int i = 0; i < n_reads; i++) {
      bnb_dequantize<T, pack_factor, quant_type>(
          src + i, T(am), dst + i * pack_factor);
    }
  }

  /// Advances src (and absmax, once a full blocksize has been consumed)
  /// to the next tile along the reduction dimension.
  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          absmax_ptr++;
        }
      } else {
        absmax_ptr++;
      }
    } else {
      absmax_ptr += group_stride;
    }
  }
};
|
| 150 |
+
|
| 151 |
+
// ============================================================================
|
| 152 |
+
// BnB GEMV (matrix-vector multiply with 4-bit quantized weights)
|
| 153 |
+
//
|
| 154 |
+
// Computes y = dequant(W) @ x
|
| 155 |
+
// W: [N, K/2] packed bytes, absmax: [N, ceil(K/blocksize)], x: [K], y: [N]
|
| 156 |
+
//
|
| 157 |
+
// Each simdgroup handles results_per_simdgroup output rows.
|
| 158 |
+
// Each thread processes values_per_thread elements of K per iteration.
|
| 159 |
+
// ============================================================================
|
| 160 |
+
|
| 161 |
+
template <typename T, int blocksize, int quant_type>
METAL_FUNC void bnb_qmv_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Tunables: 2 simdgroups per threadgroup, each producing 4 output rows;
  // every lane consumes 4 packed bytes (8 unpacked values) per iteration.
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int bytes_per_thread = 4;
  constexpr int values_per_thread = bytes_per_thread * 2;
  constexpr int block_size_k = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = blocksize / values_per_thread;

  constant float* codebook = bnb_codebook<quant_type>();

  typedef float U;  // accumulate in fp32 regardless of T
  thread U x_vals[values_per_thread];
  thread U acc[results_per_simdgroup] = {0};

  const int K_packed = in_vec_size / 2;
  const int K_groups = (in_vec_size + blocksize - 1) / blocksize;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  if (out_row >= out_vec_size) {
    return;
  }

  // Clamp so this simdgroup's 4 rows never run past the matrix edge
  // (tail simdgroups recompute overlapping rows with identical values).
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  const device uint8_t* w_row =
      w + used_out_row * K_packed + simd_lid * bytes_per_thread;
  const device float* am_row =
      absmax + used_out_row * K_groups + simd_lid / scale_step_per_thread;
  const device T* x_ptr = x + tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + used_out_row;

  int k = 0;
  for (; k < in_vec_size - block_size_k; k += block_size_k) {
    // Stage this lane's slice of x.
    for (int i = 0; i < values_per_thread; i++) {
      x_vals[i] = U(x_ptr[i]);
    }

    // Partial dot products: decode nibbles through the codebook, scale once
    // per absmax group.
    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = w_row + row * K_packed;
      const U scale = U(am_row[row * K_groups]);

      U partial = 0;
      for (int i = 0; i < bytes_per_thread; i++) {
        const uint8_t b = wl[i];
        // BnB packing: high nibble = first element, low nibble = second.
        partial += x_vals[2 * i] * U(codebook[(b >> 4) & 0x0f]);
        partial += x_vals[2 * i + 1] * U(codebook[b & 0x0f]);
      }
      acc[row] += partial * scale;
    }

    w_row += block_size_k / 2;
    am_row += block_size_k / blocksize;
    x_ptr += block_size_k;
  }

  // K tail: this lane may own fewer than values_per_thread elements.
  const int remaining = clamp(
      static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
      0,
      values_per_thread);
  if (remaining > 0) {
    for (int i = 0; i < values_per_thread; i++) {
      x_vals[i] = (i < remaining) ? U(x_ptr[i]) : U(0);
    }

    const int bytes_to_read = (remaining + 1) / 2;
    for (int row = 0; row < results_per_simdgroup; row++) {
      const device uint8_t* wl = w_row + row * K_packed;
      const U scale = U(am_row[row * K_groups]);

      U partial = 0;
      for (int i = 0; i < bytes_to_read; i++) {
        const uint8_t b = wl[i];
        partial += x_vals[2 * i] * U(codebook[(b >> 4) & 0x0f]);
        partial += x_vals[2 * i + 1] * U(codebook[b & 0x0f]);
      }
      acc[row] += partial * scale;
    }
  }

  // Cross-lane reduction; lane 0 writes the finished dot products.
  for (int row = 0; row < results_per_simdgroup; row++) {
    acc[row] = simd_sum(acc[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(acc[row]);
    }
  }
}
|
| 267 |
+
|
| 268 |
+
// ============================================================================
|
| 269 |
+
// BnB GEMM with transposed weight (y = x @ dequant(w).T)
|
| 270 |
+
//
|
| 271 |
+
// x: [M, K], w: [N, K/2] packed, absmax: [N, ceil(K/blocksize)], y: [M, N]
|
| 272 |
+
//
|
| 273 |
+
// Uses tiled matrix multiply with BnBQuantizedBlockLoader for on-the-fly
|
| 274 |
+
// dequantization of weights during the GEMM computation.
|
| 275 |
+
// ============================================================================
|
| 276 |
+
|
| 277 |
+
template <
    typename T,
    const int blocksize,
    const int quant_type,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void bnb_qmm_t_impl(
    const device uint8_t* w,
    const device float* absmax,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  // 2x2 grid of simdgroups per threadgroup.
  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = 2;

  // Padded leading dimension for the threadgroup staging tiles.
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = BnBQuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      blocksize,
      quant_type>;

  const int K_packed = K / pack_factor;
  const int K_groups = (K + blocksize - 1) / blocksize;
  const int row0 = tid.y * BM;  // first output row of this threadgroup
  const int col0 = tid.x * BN;  // first output column of this threadgroup

  // Offset all operands to this threadgroup's output tile.
  x += row0 * static_cast<int64_t>(K);
  w += col0 * K_packed;
  absmax += col0 * K_groups;
  y += row0 * static_cast<int64_t>(N) + col0;

  const short rows_valid = min(BM, M - row0);
  const short cols_valid = min(BN, N - col0);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(
      (const device uint8_t*)w, absmax, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  // The partial/full decision is hoisted out of the K loop: four variants,
  // each pairing the cheapest load available for X and W.
  if (rows_valid < BM) {
    if (cols_valid < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, rows_valid));
        loader_w.load_safe(short2(BK, cols_valid));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, rows_valid));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (cols_valid < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, cols_valid));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Write the accumulated tile back, bounds-checked only on edge tiles.
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (rows_valid < BM || cols_valid < BN) {
    mma_op.store_result_safe(y, N, short2(cols_valid, rows_valid));
  } else {
    mma_op.store_result(y, N);
  }
}
|
| 394 |
+
|
| 395 |
+
// ============================================================================
|
| 396 |
+
// Kernel entry points
|
| 397 |
+
// ============================================================================
|
| 398 |
+
|
| 399 |
+
// ---- Standalone blockwise quantize ----
|
| 400 |
+
// Each thread handles one block of elements.
|
| 401 |
+
|
| 402 |
+
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_quantize_blockwise(
    const device T* input [[buffer(0)]],
    device float* absmax [[buffer(1)]],
    device uint8_t* packed [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]]) {
  // One thread per block: reduce the block's absolute maximum, then
  // normalize, codebook-quantize, and pack two 4-bit codes per byte
  // (high nibble = first element).
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(gid) >= num_blocks) {
    return;
  }

  const int begin = gid * blocksize;
  const int end = min(begin + blocksize, n);

  // Pass 1: absolute maximum of the block.
  float amax = 0.0f;
  for (int i = begin; i < end; i++) {
    amax = metal::max(amax, metal::abs(float(input[i])));
  }
  absmax[gid] = amax;

  const float inv = (amax > 0.0f) ? 1.0f / amax : 0.0f;

  // Pass 2: normalize into [-1, 1], quantize, and pack element pairs.
  int out_byte = begin / 2;
  for (int i = begin; i < end; i += 2) {
    const float n0 =
        (amax > 0.0f) ? clamp(float(input[i]) * inv, -1.0f, 1.0f) : 0.0f;
    const uchar q0 = bnb_quantize_value<quant_type>(n0);

    // Odd-length tail: the missing second element packs as code 0.
    uchar q1 = 0;
    if (i + 1 < end) {
      const float n1 = (amax > 0.0f)
          ? clamp(float(input[i + 1]) * inv, -1.0f, 1.0f)
          : 0.0f;
      q1 = bnb_quantize_value<quant_type>(n1);
    }

    packed[out_byte++] = (q0 << 4) | (q1 & 0x0f);
  }
}
|
| 445 |
+
|
| 446 |
+
// ---- Standalone blockwise dequantize ----
|
| 447 |
+
// Each threadgroup handles one block. Threads within share the absmax.
|
| 448 |
+
|
| 449 |
+
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_dequantize_blockwise(
    const device uint8_t* packed [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    device T* output [[buffer(2)]],
    const constant int& n [[buffer(3)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint tg_size [[threads_per_threadgroup]]) {
  // One threadgroup per quantization block; threads share the block's
  // absmax and cooperatively expand the packed bytes.
  const int num_blocks = (n + blocksize - 1) / blocksize;
  if (static_cast<int>(tgid) >= num_blocks) {
    return;
  }

  constant float* codebook = bnb_codebook<quant_type>();

  int block_start = tgid * blocksize;
  int block_end = min(block_start + blocksize, n);

  // MSL forbids initializers on threadgroup address-space variables, so
  // the original `threadgroup float shared_scale = 0.0f;` is ill-formed.
  // Declare it uninitialized; thread 0 publishes the scale and everyone
  // reads it after the barrier.
  threadgroup float shared_scale;
  if (tid == 0) {
    shared_scale = absmax[tgid];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  float scale = shared_scale;

  // Each packed byte holds two elements: high nibble first, low second.
  int pairs_in_block = (block_end - block_start + 1) / 2;

  for (int pair = static_cast<int>(tid); pair < pairs_in_block;
       pair += static_cast<int>(tg_size)) {
    int elem_idx = block_start + pair * 2;
    int byte_idx = elem_idx / 2;
    uint8_t byte_val = packed[byte_idx];

    uint8_t high = (byte_val >> 4) & 0x0f;
    uint8_t low = byte_val & 0x0f;

    output[elem_idx] = T(codebook[high] * scale);
    // Guard the low nibble for an odd-length final block.
    if (elem_idx + 1 < block_end) {
      output[elem_idx + 1] = T(codebook[low] * scale);
    }
  }
}
|
| 492 |
+
|
| 493 |
+
// ---- GEMV kernel entry point ----
|
| 494 |
+
// y = dequant(W) @ x
|
| 495 |
+
// W: [N, K/2], absmax: [N, K_groups], x: [K], y: [N]
|
| 496 |
+
|
| 497 |
+
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_qmv(
    const device uint8_t* w [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& in_vec_size [[buffer(4)]],
    const constant int& out_vec_size [[buffer(5)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Thin entry point for y = dequant(W) @ x; all work happens in
  // bnb_qmv_impl.
  bnb_qmv_impl<T, blocksize, quant_type>(
      w, absmax, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
|
| 511 |
+
|
| 512 |
+
// ---- GEMM (transposed weight) kernel entry point ----
|
| 513 |
+
// Y = X @ dequant(W).T
|
| 514 |
+
// X: [M, K], W: [N, K/2], absmax: [N, K_groups], Y: [M, N]
|
| 515 |
+
|
| 516 |
+
template <typename T, int blocksize, int quant_type>
[[kernel]] void bnb_qmm_t(
    const device uint8_t* w [[buffer(0)]],
    const device float* absmax [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& K [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  // Fixed 32x32x32 tiling; the staging tiles are padded along K to reduce
  // threadgroup-memory bank conflicts.
  constexpr int BM = 32;
  constexpr int BK = 32;
  constexpr int BN = 32;
  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T x_tile[BM * BK_padded];
  threadgroup T w_tile[BN * BK_padded];

  // Y = X @ dequant(W).T — see bnb_qmm_t_impl.
  bnb_qmm_t_impl<T, blocksize, quant_type, BM, BK, BN>(
      w, absmax, x, y, x_tile, w_tile, K, N, M, tid, lid, simd_gid, simd_lid);
}
|
bitsandbytes_mps/bnb_quantized.metal
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// bitsandbytes MPS Metal kernels - template instantiations.
// Emits a [[host_name]] specialization of every kernel for each
// (scalar type, blocksize, quant_type) combination the host may request.

// clang-format off
#include "utils.h"
#include "gemm/gemm.h"
#include "quantized_utils.h"
#include "bnb_quantized.h"

// ============================================================================
// Instantiation macros
// ============================================================================

// One kernel specialization. The host looks pipelines up by the generated
// name: "<kernel>_<type>_bs_<blocksize>_qt_<quant_type>". The stringized
// macro arguments reproduce the argument spelling, so the generated host
// names are identical regardless of the parameter names used here.
#define instantiate_bnb_kernel(kname, ktype, kbs, kqt) \
  template [[host_name( \
      #kname "_" #ktype "_bs_" #kbs "_qt_" #kqt \
  )]] [[kernel]] decltype(kname<ktype, kbs, kqt>) \
      kname<ktype, kbs, kqt>;

// All four entry points for one (type, blocksize, quant_type) configuration.
#define instantiate_bnb_all_kernels(ktype, kbs, kqt) \
  instantiate_bnb_kernel(bnb_quantize_blockwise, ktype, kbs, kqt) \
  instantiate_bnb_kernel(bnb_dequantize_blockwise, ktype, kbs, kqt) \
  instantiate_bnb_kernel(bnb_qmv, ktype, kbs, kqt) \
  instantiate_bnb_kernel(bnb_qmm_t, ktype, kbs, kqt)

// Both bitsandbytes 4-bit formats: FP4 = 1, NF4 = 2.
#define instantiate_bnb_quant_types(ktype, kbs) \
  instantiate_bnb_all_kernels(ktype, kbs, 1) \
  instantiate_bnb_all_kernels(ktype, kbs, 2)

// Every supported blocksize.
#define instantiate_bnb_blocksizes(ktype) \
  instantiate_bnb_quant_types(ktype, 64) \
  instantiate_bnb_quant_types(ktype, 128) \
  instantiate_bnb_quant_types(ktype, 256) \
  instantiate_bnb_quant_types(ktype, 512)

// Every supported scalar type.
instantiate_bnb_blocksizes(half)
instantiate_bnb_blocksizes(bfloat16_t)
instantiate_bnb_blocksizes(float)

// clang-format on
|
bitsandbytes_mps/bnb_quantized.mm
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// bitsandbytes MPS Metal kernels - ObjC++ dispatch
|
| 2 |
+
// Interfaces between PyTorch MPS tensors and Metal compute kernels.
|
| 3 |
+
// Uses the same dispatch pattern as kernels-community/activation, with
|
| 4 |
+
// get_command_buffer() moved inside dispatch_sync to avoid race conditions
|
| 5 |
+
// during model loading.
|
| 6 |
+
|
| 7 |
+
#include <torch/torch.h>
|
| 8 |
+
|
| 9 |
+
#import <Foundation/Foundation.h>
|
| 10 |
+
#import <Metal/Metal.h>
|
| 11 |
+
|
| 12 |
+
#include <algorithm>
|
| 13 |
+
#include <iostream>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include <unordered_map>
|
| 16 |
+
|
| 17 |
+
#ifdef EMBEDDED_METALLIB_HEADER
|
| 18 |
+
#include EMBEDDED_METALLIB_HEADER
|
| 19 |
+
#endif
|
| 20 |
+
|
| 21 |
+
// ============================================================================
|
| 22 |
+
// Metal helpers
|
| 23 |
+
// ============================================================================
|
| 24 |
+
|
| 25 |
+
// Recover the Metal buffer backing an MPS tensor: the MPS allocator stores
// the id<MTLBuffer> as the raw storage pointer, so a bit-cast (no retain,
// no refcount change) gives it back.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& t) {
  auto raw_storage = t.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, raw_storage);
}
|
| 28 |
+
|
| 29 |
+
namespace {
|
| 30 |
+
|
| 31 |
+
static id<MTLLibrary> library = nil;
|
| 32 |
+
|
| 33 |
+
// Return the Metal library holding the BnB kernels, creating it on first
// use. Prefers the metallib embedded at build time; otherwise falls back
// to the process default library. The result is cached in the file-local
// `library` static.
id<MTLLibrary> get_library() {
  if (library != nil) {
    return library;
  }

  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  NSError* error = nil;

#ifdef EMBEDDED_METALLIB_HEADER
  library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
  if (library == nil) {
    std::cerr << "Failed to create Metal library from embedded header"
              << std::endl;
    if (error) {
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
    }
  }
#else
  library = [device newDefaultLibrary];
  if (library == nil) {
    std::cerr << "Failed to load Metal library" << std::endl;
    if (error) {
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
    }
  }
#endif
  return library;
}
|
| 59 |
+
|
| 60 |
+
// Look up (and cache) the compute pipeline for a kernel host name.
// Returns nil on failure; callers TORCH_CHECK the result.
id<MTLComputePipelineState> get_pipeline(const std::string& name) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;
  auto it = cache.find(name);
  if (it != cache.end())
    return it->second;

  id<MTLLibrary> lib = get_library();
  if (!lib)
    return nil;

  id<MTLFunction> func =
      [lib newFunctionWithName:[NSString stringWithUTF8String:name.c_str()]];
  if (!func) {
    std::cerr << "Kernel not found: " << name << std::endl;
    return nil;
  }

  NSError* error = nil;
  // Build the pipeline on the device that owns the library. The original
  // called MTLCreateSystemDefaultDevice() here, allocating a second device
  // object per cache miss; pipelines must live on the same device as the
  // library and buffers they are used with.
  id<MTLComputePipelineState> state =
      [[lib device] newComputePipelineStateWithFunction:func error:&error];
  if (!state) {
    std::cerr << "Failed to create pipeline for " << name << std::endl;
    if (error) {
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
    }
    return nil;
  }
  cache[name] = state;
  return state;
}
|
| 88 |
+
|
| 89 |
+
// Map a torch scalar type to the type suffix baked into the kernel host
// names (see the instantiation macros in bnb_quantized.metal).
// Throws for any dtype the kernels are not instantiated for.
std::string type_str(torch::ScalarType type) {
  switch (type) {
    case torch::kFloat32:
      return "float";
    case torch::kFloat16:
      return "half";
    case torch::kBFloat16:
      return "bfloat16_t";
    default:
      throw std::runtime_error("Unsupported dtype for BnB MPS kernels");
  }
}
|
| 101 |
+
|
| 102 |
+
// Bind tensor `t`'s Metal buffer to the encoder at buffer index `index`,
// honoring the tensor's storage offset (views share one buffer plus a
// byte offset).
void set_tensor(
    id<MTLComputeCommandEncoder> enc,
    const torch::Tensor& t,
    int index) {
  const auto byte_offset = t.storage_offset() * t.element_size();
  [enc setBuffer:getMTLBufferStorage(t) offset:byte_offset atIndex:index];
}
|
| 110 |
+
|
| 111 |
+
} // namespace
|
| 112 |
+
|
| 113 |
+
// ============================================================================
|
| 114 |
+
// Public API: quantize_4bit
|
| 115 |
+
// ============================================================================
|
| 116 |
+
|
| 117 |
+
// Blockwise 4-bit quantization on MPS.
//
// Splits `input` into contiguous blocks of `blocksize` elements, stores each
// block's absolute maximum as fp32, and packs two 4-bit codes per output byte
// (high nibble = first element).
//
// quant_type: 1 = FP4, 2 = NF4 (matches the Metal kernel instantiations).
// Returns (packed, absmax).
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type) {
  TORCH_CHECK(input.is_mps(), "Input must be on MPS device");
  // The kernels are instantiated for blocksizes 64/128/256/512 (see
  // bnb_quantized.metal); accept all of them rather than only 64/128.
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 ||
          blocksize == 512,
      "Only blocksize 64, 128, 256 and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");

  int n = static_cast<int>(input.numel());
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);
  int packed_size = (n + 1) / 2;

  auto absmax =
      torch::empty({num_blocks}, input.options().dtype(torch::kFloat32));
  auto packed =
      torch::empty({packed_size}, input.options().dtype(torch::kUInt8));

  // Host name: bnb_quantize_blockwise_<type>_bs_<blocksize>_qt_<quant_type>.
  std::stringstream ss;
  ss << "bnb_quantize_blockwise_" << type_str(input.scalar_type()) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // get_command_buffer() runs inside dispatch_sync so it cannot race
        // with PyTorch's own encoding (e.g. during model loading).
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, input, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, packed, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // The kernel uses one thread per quantization block.
        NSUInteger threads_per_tg = pipeline.threadExecutionWidth;
        MTLSize grid = MTLSizeMake(num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return std::make_tuple(packed, absmax);
}
|
| 178 |
+
|
| 179 |
+
// ============================================================================
|
| 180 |
+
// Public API: dequantize_blockwise
|
| 181 |
+
// ============================================================================
|
| 182 |
+
|
| 183 |
+
// Blockwise 4-bit dequantization on MPS.
//
// Inverse of bnb_quantize_4bit: expands `packed` 4-bit codes back to
// `numel` values of `output_dtype`, scaling each block by its absmax.
//
// Args:
//   packed       - uint8 tensor of packed codes (MPS device).
//   absmax       - per-block scales produced by quantization (MPS device).
//   blocksize    - 64 or 128; must match the quantization blocksize.
//   quant_type   - 1 = FP4, 2 = NF4 (not validated here; an unknown value
//                  surfaces as "Kernel not found").
//   numel        - number of original elements (packed stores 2 per byte).
//   output_dtype - dtype of the dequantized result.
// Returns: flat tensor of `numel` dequantized values.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    torch::ScalarType output_dtype) {
  TORCH_CHECK(packed.is_mps(), "packed must be on MPS device");
  TORCH_CHECK(absmax.is_mps(), "absmax must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128,
      "Only blocksize 64 and 128 are supported");

  int n = static_cast<int>(numel);
  int num_blocks =
      (n + static_cast<int>(blocksize) - 1) / static_cast<int>(blocksize);

  auto output = torch::empty({n}, packed.options().dtype(output_dtype));

  // Kernel variant is selected by OUTPUT dtype (the packed input is always
  // uint8).
  std::stringstream ss;
  ss << "bnb_dequantize_blockwise_" << type_str(output_dtype) << "_bs_"
     << blocksize << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        // Slot order: packed, absmax, output, n.
        int idx = 0;
        set_tensor(encoder, packed, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, output, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // Threadgroup sizing: aim for one thread per packed byte of a block
        // (blocksize/2), clamped to the pipeline's max, and rounded up to at
        // least one SIMD group so no execution width is wasted.
        NSUInteger max_tg = pipeline.maxTotalThreadsPerThreadgroup;
        NSUInteger desired = (blocksize + 1) / 2;
        NSUInteger tg_size =
            std::min(max_tg, std::max(static_cast<NSUInteger>(1), desired));
        if (tg_size < pipeline.threadExecutionWidth) {
          tg_size = std::min(pipeline.threadExecutionWidth, max_tg);
        }

        // One threadgroup's worth of threads per quantization block.
        MTLSize grid = MTLSizeMake(tg_size * num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(tg_size, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return output;
}
|
| 248 |
+
|
| 249 |
+
// ============================================================================
|
| 250 |
+
// Public API: GEMV (matrix-vector multiply)
|
| 251 |
+
// y = dequant(W) @ x
|
| 252 |
+
// ============================================================================
|
| 253 |
+
|
| 254 |
+
// Fused 4-bit GEMV on MPS: y = dequant(W) @ x.
//
// The kernel dequantizes W on the fly, so the weight never materializes in
// full precision. Output shape is x's shape with the last dim replaced by
// `output_features`.
//
// Args:
//   x               - activation vector/tensor (MPS); dtype selects kernel.
//   w               - packed 4-bit weight (MPS).
//   absmax          - per-block scales for w (MPS).
//   blocksize       - quantization blocksize baked into the kernel name.
//   quant_type      - 1 = FP4, 2 = NF4.
//   output_features - N, the number of output rows of W.
// Returns: y with out_sizes = x.sizes() except last dim = N.
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");

  int K = static_cast<int>(x.size(-1)); // reduction (input-feature) dim
  int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // zeros (not empty): the kernel presumably accumulates into y — confirm
  // against the Metal kernel source.
  auto y = torch::zeros(out_sizes, x.options());

  std::stringstream ss;
  ss << "bnb_qmv_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        // Slot order: w, absmax, x, y, K, N.
        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];

        // Each threadgroup covers 8 output rows with 64 threads (2 SIMD
        // groups of 32) — layout must match the kernel's expectations.
        int rows_per_tg = 8;
        int grid_y = (N + rows_per_tg - 1) / rows_per_tg;

        [encoder dispatchThreadgroups:MTLSizeMake(1, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(32 * 2, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
|
| 314 |
+
|
| 315 |
+
// ============================================================================
|
| 316 |
+
// Public API: GEMM (matrix-matrix multiply with transposed weight)
|
| 317 |
+
// Y = X @ dequant(W).T
|
| 318 |
+
// ============================================================================
|
| 319 |
+
|
| 320 |
+
// Fused 4-bit GEMM on MPS: Y = X @ dequant(W).T.
//
// Tiled matrix-matrix variant for batched inputs (M rows at a time), with W
// dequantized inside the kernel. Interface mirrors bnb_gemv_4bit but
// requires a 2-D (or higher) x so M = x.size(-2) is defined.
//
// Args: see bnb_gemv_4bit; `output_features` is N.
// Returns: y with x's shape, last dim replaced by N.
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 2, "Input must be at least 2D for GEMM");

  int K = static_cast<int>(x.size(-1)); // reduction dim
  int M = static_cast<int>(x.size(-2)); // rows of X
  int N = static_cast<int>(output_features);

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // zeros (not empty): kernel presumably accumulates — TODO confirm.
  auto y = torch::zeros(out_sizes, x.options());

  std::stringstream ss;
  ss << "bnb_qmm_t_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;

  auto pipeline = get_pipeline(ss.str());
  TORCH_CHECK(pipeline, "Kernel not found: ", ss.str());

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        // Slot order: w, absmax, x, y, K, N, M (note: K, N, then M).
        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&K length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&N length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&M length:sizeof(int) atIndex:idx++];

        // 32x32 output tile per threadgroup, 128 threads each — must match
        // the kernel's tiling constants.
        int grid_x = (N + 31) / 32;
        int grid_y = (M + 31) / 32;

        [encoder dispatchThreadgroups:MTLSizeMake(grid_x, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}
|
bitsandbytes_mps/bnb_types.h
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// bitsandbytes MPS Metal kernels - NF4/FP4 codebook definitions and helpers
|
| 2 |
+
// Adapted from bitsandbytes CUDA kernels (kernels.cu) for Apple Metal
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
#include <metal_stdlib>
|
| 7 |
+
using namespace metal;
|
| 8 |
+
|
| 9 |
+
// ============================================================================
|
| 10 |
+
// Quant type enum (matches bitsandbytes common.h)
|
| 11 |
+
// ============================================================================
|
| 12 |
+
|
| 13 |
+
// Integer codes for the two supported 4-bit formats. Values must stay in
// sync with the `quant_type` integers the host side passes (1 = FP4,
// 2 = NF4) and with bitsandbytes' common.h.
enum BnBQuantType {
  BNB_FP4 = 1,
  BNB_NF4 = 2,
};
|
| 17 |
+
|
| 18 |
+
// ============================================================================
|
| 19 |
+
// NF4 codebook - 16 values optimized for normal distributions
|
| 20 |
+
// Maps 4-bit indices (0-15) to float values in [-1, 1]
|
| 21 |
+
// ============================================================================
|
| 22 |
+
|
| 23 |
+
// 4-bit code -> normalized value lookup table for NF4. Entries are strictly
// increasing over [-1, 1] with index 7 = 0.0; quantize_nf4's thresholds are
// the midpoints between adjacent entries here.
constant float NF4_CODEBOOK[16] = {
    -1.0f,
    -0.6961928009986877f,
    -0.5250730514526367f,
    -0.39491748809814453f,
    -0.28444138169288635f,
    -0.18477343022823334f,
    -0.09105003625154495f,
    0.0f,
    0.07958029955625534f,
    0.16093020141124725f,
    0.24611230194568634f,
    0.33791524171829224f,
    0.44070982933044434f,
    0.5626170039176941f,
    0.7229568362236023f,
    1.0f,
};
|
| 41 |
+
|
| 42 |
+
// ============================================================================
|
| 43 |
+
// FP4 codebook - 16 values using sign-magnitude FP4 encoding
|
| 44 |
+
// Indices 0-7: non-negative, indices 8-15: negative (bit 3 = sign)
|
| 45 |
+
// ============================================================================
|
| 46 |
+
|
| 47 |
+
// 4-bit code -> value lookup table for FP4 (sign-magnitude layout):
// bit 3 is the sign, so entries 8..15 mirror 0..7 negated. Note entries 0
// and 8 are both 0.0 (+0 and -0 collapse).
constant float FP4_CODEBOOK[16] = {
    0.0f,
    0.005208333333f,
    0.66666667f,
    1.0f,
    0.33333333f,
    0.5f,
    0.16666667f,
    0.25f,
    0.0f,
    -0.005208333333f,
    -0.66666667f,
    -1.0f,
    -0.33333333f,
    -0.5f,
    -0.16666667f,
    -0.25f,
};
|
| 65 |
+
|
| 66 |
+
// ============================================================================
|
| 67 |
+
// Codebook accessor by quant_type template parameter
|
| 68 |
+
// ============================================================================
|
| 69 |
+
|
| 70 |
+
// Select the decode table for a quant type at compile time. Anything other
// than BNB_NF4 falls back to the FP4 table, matching the original dispatch.
template <int quant_type>
inline constant float* bnb_codebook() {
  return (quant_type == BNB_NF4) ? NF4_CODEBOOK : FP4_CODEBOOK;
}
|
| 78 |
+
|
| 79 |
+
// ============================================================================
|
| 80 |
+
// NF4 quantization - binary search (matches CUDA dQuantizeNF4)
|
| 81 |
+
// Input: normalized value in [-1, 1]
|
| 82 |
+
// Output: 4-bit index (0-15)
|
| 83 |
+
// ============================================================================
|
| 84 |
+
|
| 85 |
+
// Nearest-value NF4 encoder: maps a normalized value (expected in [-1, 1])
// to the index of the closest NF4_CODEBOOK entry via a hard-coded binary
// search. Every threshold below is the midpoint between two adjacent
// codebook entries (e.g. 0.8614784 = (0.7229568 + 1.0) / 2), mirroring the
// bitsandbytes CUDA dQuantizeNF4 routine.
inline uchar quantize_nf4(float x) {
  if (x > 0.03979014977812767f) {
    // Positive half: codes 8..15.
    if (x > 0.3893125355243683f) {
      if (x > 0.6427869200706482f) {
        return (x > 0.8614784181118011f) ? 15 : 14;
      }
      return (x > 0.5016634166240692f) ? 13 : 12;
    }
    if (x > 0.2035212516784668f) {
      return (x > 0.2920137718319893f) ? 11 : 10;
    }
    return (x > 0.1202552504837513f) ? 9 : 8;
  }
  // Non-positive half: codes 0..7.
  if (x > -0.33967943489551544f) {
    if (x > -0.13791173323988914f) {
      return (x > -0.045525018125772476f) ? 7 : 6;
    }
    return (x > -0.23460740596055984f) ? 5 : 4;
  }
  if (x > -0.6106329262256622f) {
    return (x > -0.4599952697753906f) ? 3 : 2;
  }
  return (x > -0.8480964004993439f) ? 1 : 0;
}
|
| 109 |
+
|
| 110 |
+
// ============================================================================
|
| 111 |
+
// FP4 quantization - binary search (matches CUDA dQuantizeFP4)
|
| 112 |
+
// Input: normalized value in [-1, 1]
|
| 113 |
+
// Output: 4-bit index (0-15), MSB = sign bit
|
| 114 |
+
// ============================================================================
|
| 115 |
+
|
| 116 |
+
// Nearest-value FP4 encoder: maps a normalized value (expected in [-1, 1])
// to a 4-bit sign-magnitude code (bit 3 = sign, bits 0-2 = magnitude code
// into FP4_CODEBOOK). Every threshold is the midpoint between two adjacent
// codebook magnitudes, matching bitsandbytes' CUDA dQuantizeFP4.
//
// Fix: the outer upper-branch split was 0.75f, but the midpoint between
// the codebook magnitudes 0.5 (code 5) and 0.66666667 (code 2) is
// 0.583333; with 0.75 any magnitude in (0.583333, 0.75] was rounded to 0.5
// instead of the nearer 0.66666667.
inline uchar quantize_fp4(float x) {
  uchar sign = (x < 0.0f) ? 8 : 0;
  x = metal::abs(x);
  uchar code;
  if (x > 0.29166667f) { // midpoint of 0.25 and 0.33333333
    if (x > 0.583333f) { // midpoint of 0.5 and 0.66666667
      code = (x > 0.8333333f) ? 3 : 2; // 1.0 vs 0.66666667
    } else {
      code = (x > 0.4166667f) ? 5 : 4; // 0.5 vs 0.33333333
    }
  } else {
    if (x > 0.0859375f) { // midpoint of 0.005208333 and 0.16666667
      code = (x > 0.20833333f) ? 7 : 6; // 0.25 vs 0.16666667
    } else {
      code = (x > 0.00260416f) ? 1 : 0; // 0.005208333 vs 0.0
    }
  }
  return sign | code;
}
|
| 135 |
+
|
| 136 |
+
// ============================================================================
|
| 137 |
+
// Generic quantize dispatch by quant_type
|
| 138 |
+
// ============================================================================
|
| 139 |
+
|
| 140 |
+
// Encode one normalized value with the codec selected by `quant_type`.
// Non-NF4 values use the FP4 encoder, exactly as the original dispatch did.
template <int quant_type>
inline uchar bnb_quantize_value(float normalized) {
  return (quant_type == BNB_NF4) ? quantize_nf4(normalized)
                                 : quantize_fp4(normalized);
}
|
| 148 |
+
|
| 149 |
+
// ============================================================================
|
| 150 |
+
// Dequantize a single 4-bit value using codebook lookup
|
| 151 |
+
// ============================================================================
|
| 152 |
+
|
| 153 |
+
// Decode a single 4-bit code to its normalized float via codebook lookup.
// Only the low nibble of `nibble` is used, so callers may pass a raw byte.
// The result must still be scaled by the block's absmax by the caller.
template <int quant_type>
inline float bnb_dequantize_value(uchar nibble) {
  return bnb_codebook<quant_type>()[nibble & 0x0f];
}
|
| 157 |
+
|
| 158 |
+
// ============================================================================
|
| 159 |
+
// BnB 4-bit dequantize for block loader (adapted from MLX affine dequantize)
|
| 160 |
+
// Unpacks N values from packed bytes using codebook lookup.
|
| 161 |
+
//
|
| 162 |
+
// BnB packing: high nibble = first element, low nibble = second element
|
| 163 |
+
// Each byte stores 2 4-bit values.
|
| 164 |
+
// ============================================================================
|
| 165 |
+
|
| 166 |
+
// Unpack N quantized values (N/2 bytes, two codes per byte) into
// threadgroup memory, scaling each by the block's absmax.
//
// BnB packing: the HIGH nibble holds the earlier element, the LOW nibble
// the later one. With odd N the trailing element is not emitted
// (presumably N is always even here — confirm at call sites).
template <typename U, int N, int quant_type>
inline void bnb_dequantize(
    const device uint8_t* w,
    U absmax_val,
    threadgroup U* w_local) {
  constant float* lut = bnb_codebook<quant_type>();

  for (int byte_idx = 0; byte_idx < N / 2; byte_idx++) {
    const uint8_t packed = w[byte_idx];
    w_local[2 * byte_idx] = U(lut[(packed >> 4) & 0x0f]) * absmax_val;
    w_local[2 * byte_idx + 1] = U(lut[packed & 0x0f]) * absmax_val;
  }
}
|
bitsandbytes_mps/complex.h
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_stdlib>
|
| 6 |
+
|
| 7 |
+
using namespace metal;
|
| 8 |
+
|
| 9 |
+
struct complex64_t;

// A type can convert TO complex64_t if it converts to float (the real part)
// and is not complex64_t itself.
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
    !is_same_v<T, complex64_t> && is_convertible_v<T, float>;

// A type can convert FROM complex64_t if float (or bfloat16_t) converts to
// it; the conversion keeps only the real part.
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    !is_same_v<T, complex64_t> &&
    (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);

// Minimal complex<float> for Metal (vendored from MLX). Metal requires
// separate constructor/operator overloads per address space (thread /
// threadgroup / device / constant), hence the repetition below.
struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float real, float imag) : real(real), imag(imag) {};
  constexpr complex64_t() : real(0), imag(0) {};
  constexpr complex64_t() threadgroup : real(0), imag(0) {};

  // Conversions to complex64_t: a scalar becomes (x, 0).
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  // Conversions from complex64_t: the imaginary part is discarded.
  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};
|
| 79 |
+
|
| 80 |
+
// Unary negation.
constexpr complex64_t operator-(complex64_t x) {
  return {-x.real, -x.imag};
}

// Orderings are lexicographic (real first, then imag) — complex numbers
// have no natural total order; this matches MLX's convention so that
// sort/min/max reductions are well-defined.
constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

// <= and < are derived by swapping arguments of >= and >.
constexpr bool operator<=(complex64_t a, complex64_t b) {
  return operator>=(b, a);
}

constexpr bool operator<(complex64_t a, complex64_t b) {
  return operator>(b, a);
}

constexpr bool operator==(complex64_t a, complex64_t b) {
  return a.real == b.real && a.imag == b.imag;
}
|
| 103 |
+
|
| 104 |
+
// Componentwise addition.
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

// += is overloaded per address space of the accumulator (thread /
// threadgroup / device), as Metal requires.
constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr threadgroup complex64_t& operator+=(
    threadgroup complex64_t& a,
    complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) {
  a.real += b.real;
  a.imag += b.imag;
  return a;
}

// Mixed float/complex addition treats the float as (a, 0).
constexpr complex64_t operator+(float a, complex64_t b) {
  return {a + b.real, b.imag};
}
constexpr complex64_t operator+(complex64_t a, float b) {
  return {a.real + b, a.imag};
}

constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator-(float a, complex64_t b) {
  return {a - b.real, -b.imag};
}
constexpr complex64_t operator-(complex64_t a, float b) {
  return {a.real - b, a.imag};
}

// (a.r + i a.i)(b.r + i b.i) expanded directly.
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}

// Division via multiplication by the conjugate: a/b = a*conj(b)/|b|^2.
// No scaling guard, so |b|^2 may overflow/underflow for extreme inputs.
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a.real * b.real + a.imag * b.imag;
  auto y = a.imag * b.real - a.real * b.imag;
  return {x / denom, y / denom};
}

constexpr complex64_t operator/(float a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a * b.real;
  auto y = -a * b.imag;
  return {x / denom, y / denom};
}

// Elementwise (not complex) modulo: real and imag are reduced
// independently with truncating division (int64_t cast), then adjusted so
// each remainder takes the sign of the corresponding divisor component
// (Python-style floored-mod semantics per component).
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
  if (real != 0 && (real < 0 != b.real < 0)) {
    real += b.real;
  }
  if (imag != 0 && (imag < 0 != b.imag < 0)) {
    imag += b.imag;
  }
  return {real, imag};
}
|
bitsandbytes_mps/defines.h
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
// When compiled as Metal (or for MLX JIT), globals must live in the
// `constant` address space; on the host side the qualifier disappears so the
// same header compiles in C++.
#if defined __METAL__ || defined MLX_METAL_JIT
#define MTL_CONST constant
#else
#define MTL_CONST
#endif

// Vendored MLX kernel tuning constants (reads/writes per thread, etc.).
static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
static MTL_CONST constexpr int REDUCE_N_READS = 4;
static MTL_CONST constexpr int REDUCE_N_WRITES = 4;
static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
static MTL_CONST constexpr int RMS_N_READS = 4;
static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096;

// Instantiate a templated kernel.
// Extra args are used as template parameters:
// e.g. instantiate_kernel(binary_int, binary, a, b) ->
//   [[host_name(binary_int)]] [kernel] binary<a, b>
#define instantiate_kernel(name, func, ...)                      \
  template [[host_name(                                          \
      name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>;
|
bitsandbytes_mps/gemm/defines.h
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
// Shorthand for compile-time constants stored in Metal's constant space.
#define STEEL_CONST static constant constexpr const
// Force full / disable loop unrolling via clang pragmas.
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)")
|
bitsandbytes_mps/gemm/gemm.h
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "gemm/loader.h"
|
| 6 |
+
#include "gemm/mma.h"
|
| 7 |
+
#include "gemm/params.h"
|
| 8 |
+
#include "gemm/transforms.h"
|
| 9 |
+
#include "gemm/utils.h"
|
| 10 |
+
|
| 11 |
+
using namespace metal;
|
| 12 |
+
|
| 13 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 14 |
+
// GEMM kernel class
|
| 15 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 16 |
+
|
| 17 |
+
namespace mlx {
|
| 18 |
+
namespace steel {
|
| 19 |
+
|
| 20 |
+
// Empty tag type: encodes which GEMM tile dimensions are known to be fully
// aligned, so gemm_loop can be specialized per alignment case at compile
// time without runtime flags.
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
|
| 22 |
+
|
| 23 |
+
template <
|
| 24 |
+
typename T,
|
| 25 |
+
typename U,
|
| 26 |
+
int BM,
|
| 27 |
+
int BN,
|
| 28 |
+
int BK,
|
| 29 |
+
int WM,
|
| 30 |
+
int WN,
|
| 31 |
+
bool transpose_a,
|
| 32 |
+
bool transpose_b,
|
| 33 |
+
bool MN_aligned,
|
| 34 |
+
bool K_aligned,
|
| 35 |
+
typename AccumType = typename AccumHelper<T>::accum_type,
|
| 36 |
+
typename Epilogue = TransformNone<U, AccumType>>
|
| 37 |
+
struct GEMMKernel {
|
| 38 |
+
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
|
| 39 |
+
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
|
| 40 |
+
STEEL_CONST short tgp_mem_size_a =
|
| 41 |
+
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
|
| 42 |
+
STEEL_CONST short tgp_mem_size_b =
|
| 43 |
+
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
|
| 44 |
+
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
|
| 45 |
+
|
| 46 |
+
STEEL_CONST short tgp_size = WM * WN * 32;
|
| 47 |
+
|
| 48 |
+
using loader_a_t = BlockLoader<
|
| 49 |
+
T,
|
| 50 |
+
transpose_a ? BK : BM,
|
| 51 |
+
transpose_a ? BM : BK,
|
| 52 |
+
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
| 53 |
+
!transpose_a,
|
| 54 |
+
tgp_size>;
|
| 55 |
+
using loader_b_t = BlockLoader<
|
| 56 |
+
T,
|
| 57 |
+
transpose_b ? BN : BK,
|
| 58 |
+
transpose_b ? BK : BN,
|
| 59 |
+
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
| 60 |
+
transpose_b,
|
| 61 |
+
tgp_size>;
|
| 62 |
+
using mma_t = BlockMMA<
|
| 63 |
+
T,
|
| 64 |
+
U,
|
| 65 |
+
BM,
|
| 66 |
+
BN,
|
| 67 |
+
BK,
|
| 68 |
+
WM,
|
| 69 |
+
WN,
|
| 70 |
+
transpose_a,
|
| 71 |
+
transpose_b,
|
| 72 |
+
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
|
| 73 |
+
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
|
| 74 |
+
AccumType,
|
| 75 |
+
Epilogue>;
|
| 76 |
+
|
| 77 |
+
/* Main kernel function */
|
| 78 |
+
  // Core K-loop of the tiled GEMM: repeatedly stage an A-tile and B-tile
  // into threadgroup memory, multiply-accumulate them, and advance the
  // loaders. Alignment template flags pick the fast unchecked loads when a
  // dimension is known to be tile-aligned; otherwise bounds-checked
  // load_safe is used with the actual tile extents (tgp_bm / tgp_bn).
  // `lbk` is the leftover K extent for the final partial tile when K is not
  // a multiple of BK.
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler (tag parameter exists only for template
    // deduction).
    (void)l;

    // Tile extents in (cols, rows) order, swapped under transposition.
    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);

    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      // Barrier before loads: previous iteration's mma must be done reading
      // As/Bs before they are overwritten.
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(tile_dims_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(tile_dims_B);
      }

      // Barrier after loads: all threads must see the full tiles.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Tail iteration for the leftover K-slice (always bounds-checked).
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.load_safe(tile_dims_A_last);
      loader_b.load_safe(tile_dims_B_last);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }
|
| 138 |
+
|
| 139 |
+
/* Main kernel function */
|
| 140 |
+
static METAL_FUNC void run(
|
| 141 |
+
const device T* A [[buffer(0)]],
|
| 142 |
+
const device T* B [[buffer(1)]],
|
| 143 |
+
device U* D [[buffer(2)]],
|
| 144 |
+
const constant GEMMParams* params [[buffer(3)]],
|
| 145 |
+
threadgroup T* As [[threadgroup(0)]],
|
| 146 |
+
threadgroup T* Bs [[threadgroup(1)]],
|
| 147 |
+
uint simd_lane_id [[thread_index_in_simdgroup]],
|
| 148 |
+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
| 149 |
+
uint3 tid [[threadgroup_position_in_grid]],
|
| 150 |
+
uint3 lid [[thread_position_in_threadgroup]]) {
|
| 151 |
+
// Pacifying compiler
|
| 152 |
+
(void)lid;
|
| 153 |
+
|
| 154 |
+
const int tid_y = ((tid.y) << params->swizzle_log) +
|
| 155 |
+
((tid.x) & ((1 << params->swizzle_log) - 1));
|
| 156 |
+
const int tid_x = (tid.x) >> params->swizzle_log;
|
| 157 |
+
|
| 158 |
+
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
| 159 |
+
return;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
threadgroup_barrier(mem_flags::mem_none);
|
| 163 |
+
|
| 164 |
+
// Find block in A, B, C
|
| 165 |
+
const int c_row = tid_y * BM;
|
| 166 |
+
const int c_col = tid_x * BN;
|
| 167 |
+
const size_t c_row_long = size_t(c_row);
|
| 168 |
+
const size_t c_col_long = size_t(c_col);
|
| 169 |
+
|
| 170 |
+
A += transpose_a ? c_row_long : c_row_long * params->lda;
|
| 171 |
+
B += transpose_b ? c_col_long * params->ldb : c_col_long;
|
| 172 |
+
D += c_row_long * params->ldd + c_col_long;
|
| 173 |
+
|
| 174 |
+
// Prepare threadgroup loading operations
|
| 175 |
+
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
| 176 |
+
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
| 177 |
+
|
| 178 |
+
// Prepare threadgroup mma operation
|
| 179 |
+
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
| 180 |
+
|
| 181 |
+
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
| 182 |
+
|
| 183 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 184 |
+
// MNK aligned loop
|
| 185 |
+
if (MN_aligned) {
|
| 186 |
+
for (int k = 0; k < gemm_k_iterations; k++) {
|
| 187 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 188 |
+
// Load elements into threadgroup
|
| 189 |
+
loader_a.load_unsafe();
|
| 190 |
+
loader_b.load_unsafe();
|
| 191 |
+
|
| 192 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 193 |
+
|
| 194 |
+
// Multiply and accumulate threadgroup elements
|
| 195 |
+
mma_op.mma(As, Bs);
|
| 196 |
+
|
| 197 |
+
// Prepare for next iteration
|
| 198 |
+
loader_a.next();
|
| 199 |
+
loader_b.next();
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
threadgroup_barrier(mem_flags::mem_none);
|
| 203 |
+
|
| 204 |
+
// Loop tail
|
| 205 |
+
if (!K_aligned) {
|
| 206 |
+
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
|
| 207 |
+
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
|
| 208 |
+
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
|
| 209 |
+
|
| 210 |
+
loader_a.load_safe(tile_dims_A);
|
| 211 |
+
loader_b.load_safe(tile_dims_B);
|
| 212 |
+
|
| 213 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 214 |
+
|
| 215 |
+
mma_op.mma(As, Bs);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// Store results to device memory
|
| 219 |
+
mma_op.store_result(D, params->ldd);
|
| 220 |
+
return;
|
| 221 |
+
|
| 222 |
+
}
|
| 223 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 224 |
+
// MN unaligned loop
|
| 225 |
+
else { // Loop over K - unaligned case
|
| 226 |
+
short tgp_bm = min(BM, params->M - c_row);
|
| 227 |
+
short tgp_bn = min(BN, params->N - c_col);
|
| 228 |
+
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
|
| 229 |
+
|
| 230 |
+
if (tgp_bm == BM && tgp_bn == BN) {
|
| 231 |
+
gemm_loop<true, true, K_aligned>(
|
| 232 |
+
As,
|
| 233 |
+
Bs,
|
| 234 |
+
gemm_k_iterations,
|
| 235 |
+
loader_a,
|
| 236 |
+
loader_b,
|
| 237 |
+
mma_op,
|
| 238 |
+
tgp_bm,
|
| 239 |
+
tgp_bn,
|
| 240 |
+
leftover_bk);
|
| 241 |
+
|
| 242 |
+
mma_op.store_result(D, params->ldd);
|
| 243 |
+
return;
|
| 244 |
+
|
| 245 |
+
} else if (tgp_bn == BN) {
|
| 246 |
+
gemm_loop<false, true, K_aligned>(
|
| 247 |
+
As,
|
| 248 |
+
Bs,
|
| 249 |
+
gemm_k_iterations,
|
| 250 |
+
loader_a,
|
| 251 |
+
loader_b,
|
| 252 |
+
mma_op,
|
| 253 |
+
tgp_bm,
|
| 254 |
+
tgp_bn,
|
| 255 |
+
leftover_bk);
|
| 256 |
+
|
| 257 |
+
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
| 258 |
+
return;
|
| 259 |
+
|
| 260 |
+
} else if (tgp_bm == BM) {
|
| 261 |
+
gemm_loop<true, false, K_aligned>(
|
| 262 |
+
As,
|
| 263 |
+
Bs,
|
| 264 |
+
gemm_k_iterations,
|
| 265 |
+
loader_a,
|
| 266 |
+
loader_b,
|
| 267 |
+
mma_op,
|
| 268 |
+
tgp_bm,
|
| 269 |
+
tgp_bn,
|
| 270 |
+
leftover_bk);
|
| 271 |
+
|
| 272 |
+
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
| 273 |
+
return;
|
| 274 |
+
|
| 275 |
+
} else {
|
| 276 |
+
gemm_loop<false, false, K_aligned>(
|
| 277 |
+
As,
|
| 278 |
+
Bs,
|
| 279 |
+
gemm_k_iterations,
|
| 280 |
+
loader_a,
|
| 281 |
+
loader_b,
|
| 282 |
+
mma_op,
|
| 283 |
+
tgp_bm,
|
| 284 |
+
tgp_bn,
|
| 285 |
+
leftover_bk);
|
| 286 |
+
|
| 287 |
+
mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
|
| 288 |
+
return;
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
};
|
| 293 |
+
|
| 294 |
+
} // namespace steel
|
| 295 |
+
} // namespace mlx
|
bitsandbytes_mps/gemm/loader.h
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "gemm/defines.h"
|
| 6 |
+
|
| 7 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
// Loading helper
|
| 9 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 10 |
+
|
| 11 |
+
namespace mlx {
|
| 12 |
+
namespace steel {
|
| 13 |
+
|
| 14 |
+
template <
|
| 15 |
+
typename T,
|
| 16 |
+
short BROWS,
|
| 17 |
+
short BCOLS,
|
| 18 |
+
short dst_ld,
|
| 19 |
+
short reduction_dim,
|
| 20 |
+
short tgp_size,
|
| 21 |
+
short alignment = 1,
|
| 22 |
+
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
| 23 |
+
short TCOLS = BCOLS / n_reads,
|
| 24 |
+
short TROWS = tgp_size / TCOLS>
|
| 25 |
+
struct BlockLoader {
|
| 26 |
+
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
| 27 |
+
STEEL_CONST short vec_size = n_reads;
|
| 28 |
+
|
| 29 |
+
// Leading dimension for src
|
| 30 |
+
const int src_ld;
|
| 31 |
+
const int tile_stride;
|
| 32 |
+
|
| 33 |
+
// Thread location indices
|
| 34 |
+
const short thread_idx;
|
| 35 |
+
const short bi;
|
| 36 |
+
const short bj;
|
| 37 |
+
|
| 38 |
+
// threadgroup and device memory
|
| 39 |
+
threadgroup T* dst;
|
| 40 |
+
const device T* src;
|
| 41 |
+
|
| 42 |
+
struct alignas(alignment * sizeof(T)) ReadVector {
|
| 43 |
+
uint8_t v[sizeof(T) * vec_size];
|
| 44 |
+
};
|
| 45 |
+
|
| 46 |
+
/* Constructor */
|
| 47 |
+
METAL_FUNC BlockLoader(
|
| 48 |
+
const device T* src_,
|
| 49 |
+
const int src_ld_,
|
| 50 |
+
threadgroup T* dst_,
|
| 51 |
+
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
| 52 |
+
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
| 53 |
+
: src_ld(src_ld_),
|
| 54 |
+
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
| 55 |
+
thread_idx(simd_group_id * 32 + simd_lane_id),
|
| 56 |
+
bi(thread_idx / TCOLS),
|
| 57 |
+
bj(vec_size * (thread_idx % TCOLS)),
|
| 58 |
+
dst(dst_ + bi * dst_ld + bj),
|
| 59 |
+
src(src_ + bi * src_ld + bj) {}
|
| 60 |
+
|
| 61 |
+
/* Apply operation to threadgroup without bound checking */
|
| 62 |
+
template <typename UnaryOp>
|
| 63 |
+
METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
|
| 64 |
+
STEEL_PRAGMA_UNROLL
|
| 65 |
+
for (short i = 0; i < BROWS; i += TROWS) {
|
| 66 |
+
STEEL_PRAGMA_UNROLL
|
| 67 |
+
for (short j = 0; j < vec_size; j++) {
|
| 68 |
+
dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
/* Load from device memory into threadgroup memory - without bound checking */
|
| 74 |
+
METAL_FUNC void load_unsafe() const {
|
| 75 |
+
STEEL_PRAGMA_UNROLL
|
| 76 |
+
for (short i = 0; i < BROWS; i += TROWS) {
|
| 77 |
+
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
| 78 |
+
*((const device ReadVector*)(&src[i * src_ld]));
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
/* Load from device memory into threadgroup memory - with bound checking */
|
| 83 |
+
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
| 84 |
+
src_tile_dim = src_tile_dim - short2(bj, bi);
|
| 85 |
+
|
| 86 |
+
// Skip loading if thread has no valid reads
|
| 87 |
+
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
| 88 |
+
STEEL_PRAGMA_UNROLL
|
| 89 |
+
for (short i = 0; i < BROWS; i += TROWS) {
|
| 90 |
+
STEEL_PRAGMA_UNROLL
|
| 91 |
+
for (short j = 0; j < vec_size; j++) {
|
| 92 |
+
dst[i * dst_ld + j] = T(0);
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
return;
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
// Use fast thread memory for bound checks
|
| 99 |
+
bool tmp_idx[vec_size];
|
| 100 |
+
T tmp_val[vec_size];
|
| 101 |
+
|
| 102 |
+
STEEL_PRAGMA_UNROLL
|
| 103 |
+
for (short i = 0; i < BROWS; i += TROWS) {
|
| 104 |
+
// Make sure tmp_idx only contains valid indices
|
| 105 |
+
STEEL_PRAGMA_UNROLL
|
| 106 |
+
for (short j = 0; j < vec_size; j++) {
|
| 107 |
+
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
// Read valid indices into tmp_val
|
| 111 |
+
STEEL_PRAGMA_UNROLL
|
| 112 |
+
for (short j = 0; j < vec_size; j++) {
|
| 113 |
+
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// Zero out unneeded values
|
| 117 |
+
STEEL_PRAGMA_UNROLL
|
| 118 |
+
for (short j = 0; j < vec_size; j++) {
|
| 119 |
+
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
// Copy values to threadgroup memory
|
| 123 |
+
STEEL_PRAGMA_UNROLL
|
| 124 |
+
for (short j = 0; j < vec_size; j++) {
|
| 125 |
+
dst[i * dst_ld + j] = tmp_val[j];
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/* Iteration helper */
|
| 131 |
+
METAL_FUNC void next() {
|
| 132 |
+
src += tile_stride;
|
| 133 |
+
}
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
} // namespace steel
|
| 137 |
+
} // namespace mlx
|
bitsandbytes_mps/gemm/mma.h
ADDED
|
@@ -0,0 +1,735 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_simdgroup>
|
| 6 |
+
#include <metal_simdgroup_matrix>
|
| 7 |
+
#include <metal_stdlib>
|
| 8 |
+
|
| 9 |
+
#include "gemm/defines.h"
|
| 10 |
+
#include "gemm/transforms.h"
|
| 11 |
+
#include "gemm/utils/integral_constant.h"
|
| 12 |
+
|
| 13 |
+
using namespace metal;
|
| 14 |
+
|
| 15 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 16 |
+
// MMA helper
|
| 17 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 18 |
+
|
| 19 |
+
namespace mlx {
|
| 20 |
+
namespace steel {
|
| 21 |
+
|
| 22 |
+
// Primary template is intentionally unusable: only the 8x8 partial
// specialization below is implemented, matching the fragment shape of
// Metal's simdgroup matrix operations. Instantiating any other shape
// fails at compile time with a clear message.
template <typename T, int kFragRows_, int kFragCols_>
struct BaseMMAFrag {
  static_assert(
      kFragRows_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
  static_assert(
      kFragCols_ == 8,
      "Only 8 x 8 fragment matrices are currently supported");
};
|
| 31 |
+
|
| 32 |
+
template <typename T>
|
| 33 |
+
struct BaseMMAFrag<T, 8, 8> {
|
| 34 |
+
STEEL_CONST int kFragRows = 8;
|
| 35 |
+
STEEL_CONST int kFragCols = 8;
|
| 36 |
+
|
| 37 |
+
STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32;
|
| 38 |
+
|
| 39 |
+
STEEL_CONST int kElemRows = 1;
|
| 40 |
+
STEEL_CONST int kElemCols = 2;
|
| 41 |
+
|
| 42 |
+
static_assert(
|
| 43 |
+
kElemRows * kElemCols == kElemsPerFrag,
|
| 44 |
+
"MMAFrag shape is not consistent with MMAFrag size");
|
| 45 |
+
|
| 46 |
+
typedef metal::simdgroup_matrix<T, kFragRows, kFragCols> mat_type;
|
| 47 |
+
typedef metal::vec<T, kElemsPerFrag> frag_type;
|
| 48 |
+
|
| 49 |
+
METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id
|
| 50 |
+
[[thread_index_in_simdgroup]]) {
|
| 51 |
+
const short qid = simd_lane_id / 4;
|
| 52 |
+
const short fm = (qid & 4) + ((simd_lane_id / 2) % 4);
|
| 53 |
+
const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
| 54 |
+
return short2{fn, fm};
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
template <typename SrcPtrType, typename StrX, typename StrY>
|
| 58 |
+
METAL_FUNC static constexpr void
|
| 59 |
+
load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) {
|
| 60 |
+
STEEL_PRAGMA_UNROLL
|
| 61 |
+
for (short i = 0; i < kElemRows; i++) {
|
| 62 |
+
STEEL_PRAGMA_UNROLL
|
| 63 |
+
for (short j = 0; j < kElemCols; j++) {
|
| 64 |
+
dst[i * kElemCols + j] = static_cast<T>(src[i * str_x + j * str_y]);
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
template <
|
| 70 |
+
typename SrcPtrType,
|
| 71 |
+
typename StrX,
|
| 72 |
+
typename StrY,
|
| 73 |
+
typename LimX,
|
| 74 |
+
typename LimY,
|
| 75 |
+
typename OffX,
|
| 76 |
+
typename OffY>
|
| 77 |
+
METAL_FUNC static constexpr void load_safe(
|
| 78 |
+
thread frag_type& dst,
|
| 79 |
+
SrcPtrType src,
|
| 80 |
+
StrX str_x,
|
| 81 |
+
StrY str_y,
|
| 82 |
+
LimX lim_x,
|
| 83 |
+
LimY lim_y,
|
| 84 |
+
OffX off_x = Int<0>{},
|
| 85 |
+
OffY off_y = Int<0>{}) {
|
| 86 |
+
STEEL_PRAGMA_UNROLL
|
| 87 |
+
for (short i = 0; i < kElemRows; i++) {
|
| 88 |
+
STEEL_PRAGMA_UNROLL
|
| 89 |
+
for (short j = 0; j < kElemCols; j++) {
|
| 90 |
+
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
| 91 |
+
dst[i * kElemCols + j] =
|
| 92 |
+
static_cast<T>(src[(off_x + i) * str_x + (off_x + j) * str_y]);
|
| 93 |
+
} else {
|
| 94 |
+
dst[i * kElemCols + j] = T(0);
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
template <typename DstPtrType, typename StrX, typename StrY>
|
| 101 |
+
METAL_FUNC static constexpr void
|
| 102 |
+
store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) {
|
| 103 |
+
using U = pointer_element_t<DstPtrType>;
|
| 104 |
+
|
| 105 |
+
STEEL_PRAGMA_UNROLL
|
| 106 |
+
for (short i = 0; i < kElemRows; i++) {
|
| 107 |
+
STEEL_PRAGMA_UNROLL
|
| 108 |
+
for (short j = 0; j < kElemCols; j++) {
|
| 109 |
+
dst[i * str_x + j * str_y] = static_cast<U>(src[i * kElemCols + j]);
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
template <
|
| 115 |
+
typename DstPtrType,
|
| 116 |
+
typename StrX,
|
| 117 |
+
typename StrY,
|
| 118 |
+
typename LimX,
|
| 119 |
+
typename LimY,
|
| 120 |
+
typename OffX,
|
| 121 |
+
typename OffY>
|
| 122 |
+
METAL_FUNC static constexpr void store_safe(
|
| 123 |
+
const thread frag_type& src,
|
| 124 |
+
DstPtrType dst,
|
| 125 |
+
StrX str_x,
|
| 126 |
+
StrY str_y,
|
| 127 |
+
LimX lim_x,
|
| 128 |
+
LimY lim_y,
|
| 129 |
+
OffX off_x = Int<0>{},
|
| 130 |
+
OffY off_y = Int<0>{}) {
|
| 131 |
+
using U = pointer_element_t<DstPtrType>;
|
| 132 |
+
|
| 133 |
+
STEEL_PRAGMA_UNROLL
|
| 134 |
+
for (short i = 0; i < kElemRows; i++) {
|
| 135 |
+
STEEL_PRAGMA_UNROLL
|
| 136 |
+
for (short j = 0; j < kElemCols; j++) {
|
| 137 |
+
if ((off_x + i) < lim_x && (off_y + j) < lim_y) {
|
| 138 |
+
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
| 139 |
+
static_cast<U>(src[i * kElemCols + j]);
|
| 140 |
+
}
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
template <
|
| 146 |
+
typename DstPtrType,
|
| 147 |
+
typename StrX,
|
| 148 |
+
typename StrY,
|
| 149 |
+
typename StartX,
|
| 150 |
+
typename StopX,
|
| 151 |
+
typename StartY,
|
| 152 |
+
typename StopY,
|
| 153 |
+
typename OffX,
|
| 154 |
+
typename OffY>
|
| 155 |
+
METAL_FUNC static constexpr void store_slice(
|
| 156 |
+
const thread frag_type& src,
|
| 157 |
+
DstPtrType dst,
|
| 158 |
+
StrX str_x,
|
| 159 |
+
StrY str_y,
|
| 160 |
+
StartX start_x,
|
| 161 |
+
StopX stop_x,
|
| 162 |
+
StartY start_y,
|
| 163 |
+
StopY stop_y,
|
| 164 |
+
OffX off_x = Int<0>{},
|
| 165 |
+
OffY off_y = Int<0>{}) {
|
| 166 |
+
using U = pointer_element_t<DstPtrType>;
|
| 167 |
+
|
| 168 |
+
STEEL_PRAGMA_UNROLL
|
| 169 |
+
for (short i = 0; i < kElemRows; i++) {
|
| 170 |
+
STEEL_PRAGMA_UNROLL
|
| 171 |
+
for (short j = 0; j < kElemCols; j++) {
|
| 172 |
+
if ((off_x + i) < stop_x && (off_x + i) >= start_x &&
|
| 173 |
+
(off_y + j) < stop_y && (off_y + j) >= start_y) {
|
| 174 |
+
dst[(off_x + i) * str_x + (off_y + j) * str_y] =
|
| 175 |
+
static_cast<U>(src[i * kElemCols + j]);
|
| 176 |
+
}
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
METAL_FUNC static constexpr void mma(
|
| 182 |
+
thread frag_type& D,
|
| 183 |
+
thread frag_type& A,
|
| 184 |
+
thread frag_type& B,
|
| 185 |
+
thread frag_type& C) {
|
| 186 |
+
mat_type D_mat;
|
| 187 |
+
mat_type A_mat;
|
| 188 |
+
mat_type B_mat;
|
| 189 |
+
mat_type C_mat;
|
| 190 |
+
|
| 191 |
+
reinterpret_cast<thread frag_type&>(A_mat.thread_elements()) = A;
|
| 192 |
+
reinterpret_cast<thread frag_type&>(B_mat.thread_elements()) = B;
|
| 193 |
+
reinterpret_cast<thread frag_type&>(C_mat.thread_elements()) = C;
|
| 194 |
+
|
| 195 |
+
mma(D_mat, A_mat, B_mat, C_mat);
|
| 196 |
+
|
| 197 |
+
D = reinterpret_cast<thread frag_type&>(D_mat.thread_elements());
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
METAL_FUNC static constexpr void mma(
|
| 201 |
+
thread mat_type& D,
|
| 202 |
+
thread mat_type& A,
|
| 203 |
+
thread mat_type& B,
|
| 204 |
+
thread mat_type& C) {
|
| 205 |
+
simdgroup_multiply_accumulate(D, A, B, C);
|
| 206 |
+
}
|
| 207 |
+
};
|
| 208 |
+
|
| 209 |
+
// A register-resident tile of kTileRows_ x kTileCols_ MMA fragments
// (each fragment is 8x8, held per-lane by MMAFrag_). Provides elementwise
// access plus bulk load/store against threadgroup and device memory,
// including bounds-checked and sliced variants.
template <
    typename T,
    int kTileRows_,
    int kTileCols_,
    class MMAFrag_ = BaseMMAFrag<T, 8, 8>>
struct MMATile {
  using MMAFrag_t = MMAFrag_;
  using elem_type = T;
  STEEL_CONST int kFragRows = MMAFrag_t::kFragRows;
  STEEL_CONST int kFragCols = MMAFrag_t::kFragCols;
  STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag;

  STEEL_CONST int kTileRows = kTileRows_;
  STEEL_CONST int kTileCols = kTileCols_;

  // Overall tile extent in matrix elements.
  STEEL_CONST int kRows = kTileRows * kFragRows;
  STEEL_CONST int kCols = kTileCols * kFragCols;

  STEEL_CONST int kNumFrags = kTileRows * kTileCols;
  STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag;

  typedef typename MMAFrag_t::mat_type mat_type;
  typedef typename MMAFrag_t::frag_type frag_type;

  // Per-lane storage for all fragments, zero-initialized.
  frag_type val_frags[kNumFrags] = {frag_type(0)};

  METAL_FUNC MMATile() thread {}

  // Reset every fragment to zero.
  METAL_FUNC constexpr void clear() {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kNumFrags; ++i) {
      val_frags[i] = frag_type(0);
    }
  }

  // Access fragment (i, j) in row-major fragment order.
  METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) {
    return val_frags[i * kTileCols + j];
  }

  METAL_FUNC constexpr const thread frag_type& frag_at(
      const short i,
      const short j) const {
    return val_frags[i * kTileCols + j];
  }

  // Materialize fragment (i, j) as a simdgroup matrix (copies elements).
  METAL_FUNC mat_type mat_at(const short i, const short j) {
    mat_type val_mat;
    STEEL_PRAGMA_UNROLL
    for (short ii = 0; ii < kElemsPerFrag; ++ii) {
      val_mat.thread_elements()[ii] = frag_at(i, j)[ii];
    }
    return val_mat;
  }

  // Flat element view over all fragments (this lane's elements only).
  METAL_FUNC thread elem_type* elems() {
    return reinterpret_cast<thread elem_type*>(val_frags);
  }

  METAL_FUNC const thread elem_type* elems() const {
    return reinterpret_cast<const thread elem_type*>(val_frags);
  }

  // Load the tile from threadgroup memory. w_x/w_y are the warp-tile
  // spacing factors between fragments; str_x/str_y are compile-time
  // element strides.
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void load(const threadgroup U* src) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(
                src[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  // Store the tile to threadgroup memory (same layout as load above).
  template <typename U, int w_x, int w_y, int str_x, int str_y>
  METAL_FUNC void store(threadgroup U* dst) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(
                dst[(i * kFragRows) * w_x * str_x +
                    (j * kFragCols) * w_y * str_y]),
            Int<str_x>{},
            Int<str_y>{});
      }
    }
  }

  // Load the tile from device memory with runtime leading dimension ld
  // (row-major, unit column stride).
  template <typename U, int w_x, int w_y>
  METAL_FUNC void load(const device U* src, const int ld) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load(
            frag_at(i, j),
            &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  // Store the tile to device memory with runtime leading dimension ld.
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store(device U* dst, const int ld) const {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (short j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store(
            frag_at(i, j),
            &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]),
            ld,
            Int<1>{});
      }
    }
  }

  // Bounds-checked device load; src_tile_dims is (cols, rows) — note the
  // short2 .x/.y swap when passed down as (lim_x=rows, lim_y=cols).
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  load_safe(const device U* src, const int ld, const short2 src_tile_dims) {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::load_safe(
            frag_at(i, j),
            src,
            ld,
            Int<1>{},
            src_tile_dims.y,
            src_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  // Bounds-checked device store; dst_tile_dims is (cols, rows).
  template <typename U, int w_x, int w_y>
  METAL_FUNC void
  store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_safe(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            dst_tile_dims.y,
            dst_tile_dims.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }

  // Store only the window [start, stop) (given as (col, row) pairs).
  template <typename U, int w_x, int w_y>
  METAL_FUNC void store_slice(
      device U* dst,
      const int ld,
      const short2 start,
      const short2 stop) const {
    STEEL_PRAGMA_UNROLL
    for (int i = 0; i < kTileRows; ++i) {
      STEEL_PRAGMA_UNROLL
      for (int j = 0; j < kTileCols; ++j) {
        MMAFrag_t::store_slice(
            frag_at(i, j),
            dst,
            ld,
            Int<1>{},
            start.y,
            stop.y,
            start.x,
            stop.x,
            (i * kFragRows) * w_x,
            (j * kFragCols) * w_y);
      }
    }
  }
};
|
| 400 |
+
|
| 401 |
+
template <typename T, typename U, int M, int N, int K>
|
| 402 |
+
METAL_FUNC void tile_matmad(
|
| 403 |
+
thread MMATile<T, M, N>& D,
|
| 404 |
+
thread MMATile<U, M, K>& A,
|
| 405 |
+
thread MMATile<U, K, N>& B,
|
| 406 |
+
thread MMATile<T, M, N>& C) {
|
| 407 |
+
STEEL_PRAGMA_UNROLL
|
| 408 |
+
for (short m = 0; m < M; ++m) {
|
| 409 |
+
STEEL_PRAGMA_UNROLL
|
| 410 |
+
for (short n = 0; n < N; ++n) {
|
| 411 |
+
short n_serp = (m % 2) ? (N - 1 - n) : n;
|
| 412 |
+
STEEL_PRAGMA_UNROLL
|
| 413 |
+
for (short k = 0; k < K; ++k) {
|
| 414 |
+
MMATile<T, M, N>::MMAFrag_t::mma(
|
| 415 |
+
D.frag_at(m, n_serp),
|
| 416 |
+
A.frag_at(m, k),
|
| 417 |
+
B.frag_at(k, n_serp),
|
| 418 |
+
C.frag_at(m, n_serp));
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
template <
|
| 425 |
+
typename T,
|
| 426 |
+
typename U,
|
| 427 |
+
int BM,
|
| 428 |
+
int BN,
|
| 429 |
+
int BK,
|
| 430 |
+
int WM,
|
| 431 |
+
int WN,
|
| 432 |
+
bool transpose_a,
|
| 433 |
+
bool transpose_b,
|
| 434 |
+
short lda_tgp,
|
| 435 |
+
short ldb_tgp,
|
| 436 |
+
typename AccumType = float,
|
| 437 |
+
typename Epilogue = TransformNone<U, AccumType>>
|
| 438 |
+
struct BlockMMA {
|
| 439 |
+
// MMAFrag size
|
| 440 |
+
STEEL_CONST short kFragSize = 8;
|
| 441 |
+
using MMAFrag_acc_t = BaseMMAFrag<AccumType, kFragSize, kFragSize>;
|
| 442 |
+
|
| 443 |
+
// Warp tile simdgroup matrix strides along M
|
| 444 |
+
STEEL_CONST short TM_stride = kFragSize * WM;
|
| 445 |
+
// Warp tile simdgroup matrix strides along M
|
| 446 |
+
STEEL_CONST short TN_stride = kFragSize * WN;
|
| 447 |
+
|
| 448 |
+
// Warp tile size along M
|
| 449 |
+
STEEL_CONST short TM = BM / (kFragSize * WM);
|
| 450 |
+
// Warp tile size along N
|
| 451 |
+
STEEL_CONST short TN = BN / (kFragSize * WN);
|
| 452 |
+
|
| 453 |
+
// Threadgroup A strides
|
| 454 |
+
STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M
|
| 455 |
+
STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K
|
| 456 |
+
|
| 457 |
+
// Threadgroup B strides
|
| 458 |
+
STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K
|
| 459 |
+
STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N
|
| 460 |
+
|
| 461 |
+
// Threadgroup strides along K
|
| 462 |
+
STEEL_CONST short tile_stride_a = kFragSize * A_str_k;
|
| 463 |
+
STEEL_CONST short tile_stride_b = kFragSize * B_str_k;
|
| 464 |
+
|
| 465 |
+
// Simdgroup matrices
|
| 466 |
+
MMATile<AccumType, TM, 1, MMAFrag_acc_t> Atile;
|
| 467 |
+
MMATile<AccumType, 1, TN, MMAFrag_acc_t> Btile;
|
| 468 |
+
MMATile<AccumType, TM, TN, MMAFrag_acc_t> Ctile;
|
| 469 |
+
|
| 470 |
+
// Offsets within threadgroup
|
| 471 |
+
short sm;
|
| 472 |
+
short sn;
|
| 473 |
+
|
| 474 |
+
short As_offset;
|
| 475 |
+
short Bs_offset;
|
| 476 |
+
|
| 477 |
+
/* Constructor */
|
| 478 |
+
METAL_FUNC BlockMMA(
|
| 479 |
+
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
| 480 |
+
ushort simd_lane_id [[thread_index_in_simdgroup]]) {
|
| 481 |
+
// Determine thread position in simdgroup matrix
|
| 482 |
+
short tm = kFragSize * (simd_group_id / WN);
|
| 483 |
+
short tn = kFragSize * (simd_group_id % WN);
|
| 484 |
+
|
| 485 |
+
short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id);
|
| 486 |
+
sm = simd_coord.y;
|
| 487 |
+
sn = simd_coord.x;
|
| 488 |
+
|
| 489 |
+
// Determine thread and simdgroup offset
|
| 490 |
+
As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K
|
| 491 |
+
Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N
|
| 492 |
+
|
| 493 |
+
sm += tm;
|
| 494 |
+
sn += tn;
|
| 495 |
+
}
|
| 496 |
+
|
| 497 |
+
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
| 498 |
+
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
| 499 |
+
// Adjust for simdgroup and thread location
|
| 500 |
+
As += As_offset;
|
| 501 |
+
Bs += Bs_offset;
|
| 502 |
+
|
| 503 |
+
// Iterate over BK in blocks of kFragSize
|
| 504 |
+
STEEL_PRAGMA_UNROLL
|
| 505 |
+
for (short kk = 0; kk < BK; kk += kFragSize) {
|
| 506 |
+
simdgroup_barrier(mem_flags::mem_none);
|
| 507 |
+
|
| 508 |
+
Atile.template load<T, WM, 1, A_str_m, A_str_k>(As);
|
| 509 |
+
|
| 510 |
+
simdgroup_barrier(mem_flags::mem_none);
|
| 511 |
+
|
| 512 |
+
Btile.template load<T, 1, WN, B_str_k, B_str_n>(Bs);
|
| 513 |
+
|
| 514 |
+
simdgroup_barrier(mem_flags::mem_none);
|
| 515 |
+
|
| 516 |
+
tile_matmad(Ctile, Atile, Btile, Ctile);
|
| 517 |
+
|
| 518 |
+
// Progress to next simdgroup tile
|
| 519 |
+
As += tile_stride_a;
|
| 520 |
+
Bs += tile_stride_b;
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
/* Store results from simdgroup_matrix results into device memory */
|
| 525 |
+
METAL_FUNC void store_result(device U* D, const int ldd) {
|
| 526 |
+
// Apply epilogue
|
| 527 |
+
STEEL_PRAGMA_UNROLL
|
| 528 |
+
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
| 529 |
+
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
// Adjust for simdgroup and thread location
|
| 533 |
+
D += sm * ldd + sn;
|
| 534 |
+
|
| 535 |
+
Ctile.template store<U, WM, WN>(D, ldd);
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
METAL_FUNC void
|
| 539 |
+
store_result_slice(device U* D, const int ldd, short2 start, short2 stop) {
|
| 540 |
+
// Apply epilogue
|
| 541 |
+
STEEL_PRAGMA_UNROLL
|
| 542 |
+
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
| 543 |
+
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
D += sm * ldd + sn;
|
| 547 |
+
start -= short2(sn, sm);
|
| 548 |
+
stop -= short2(sn, sm);
|
| 549 |
+
|
| 550 |
+
// TODO: Check the start as well
|
| 551 |
+
if (stop.y <= 0 || stop.x <= 0) {
|
| 552 |
+
return;
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
Ctile.template store_slice<U, WM, WN>(D, ldd, start, stop);
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
METAL_FUNC void
|
| 559 |
+
store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) {
|
| 560 |
+
// Apply epilogue
|
| 561 |
+
STEEL_PRAGMA_UNROLL
|
| 562 |
+
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
| 563 |
+
Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]);
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
// Adjust for simdgroup and thread location
|
| 567 |
+
D += sm * ldd + sn;
|
| 568 |
+
dst_tile_dims -= short2(sn, sm);
|
| 569 |
+
|
| 570 |
+
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
| 571 |
+
return;
|
| 572 |
+
|
| 573 |
+
Ctile.template store_safe<U, WM, WN>(D, ldd, dst_tile_dims);
|
| 574 |
+
}
|
| 575 |
+
|
| 576 |
+
/* Apply epilogue */
|
| 577 |
+
template <typename UnaryEpilogue>
|
| 578 |
+
METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) {
|
| 579 |
+
// Loop over all simdgroup tiles
|
| 580 |
+
STEEL_PRAGMA_UNROLL
|
| 581 |
+
for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) {
|
| 582 |
+
Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]);
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
/* Apply epilogue */
|
| 587 |
+
template <typename BinaryEpilogue>
|
| 588 |
+
METAL_FUNC void apply_epilogue(
|
| 589 |
+
const device U* C,
|
| 590 |
+
const int ldc,
|
| 591 |
+
const int fdc,
|
| 592 |
+
thread const BinaryEpilogue& epilogue_op) {
|
| 593 |
+
// Adjust for simdgroup and thread location
|
| 594 |
+
C += (sm)*ldc + (sn)*fdc;
|
| 595 |
+
|
| 596 |
+
// Loop over all simdgroup tiles
|
| 597 |
+
STEEL_PRAGMA_UNROLL
|
| 598 |
+
for (short i = 0; i < TM; i++) {
|
| 599 |
+
STEEL_PRAGMA_UNROLL
|
| 600 |
+
for (short j = 0; j < TN; j++) {
|
| 601 |
+
// Get accumulated result and associated offset in C
|
| 602 |
+
thread auto& accum = Ctile.frag_at(i, j);
|
| 603 |
+
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
| 604 |
+
|
| 605 |
+
// Apply epilogue
|
| 606 |
+
STEEL_PRAGMA_UNROLL
|
| 607 |
+
for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) {
|
| 608 |
+
accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
| 609 |
+
}
|
| 610 |
+
}
|
| 611 |
+
}
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
/* Apply epilogue */
|
| 615 |
+
template <typename BinaryEpilogue>
|
| 616 |
+
METAL_FUNC void apply_epilogue_safe(
|
| 617 |
+
const device U* C,
|
| 618 |
+
const int ldc,
|
| 619 |
+
const int fdc,
|
| 620 |
+
short2 dst_tile_dims,
|
| 621 |
+
thread const BinaryEpilogue& epilogue_op) {
|
| 622 |
+
// Adjust for simdgroup and thread location
|
| 623 |
+
C += (sm)*ldc + (sn)*fdc;
|
| 624 |
+
dst_tile_dims -= short2(sn, sm);
|
| 625 |
+
|
| 626 |
+
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
| 627 |
+
return;
|
| 628 |
+
|
| 629 |
+
// Loop over all simdgroup tiles
|
| 630 |
+
STEEL_PRAGMA_UNROLL
|
| 631 |
+
for (short i = 0; i < TM; i++) {
|
| 632 |
+
STEEL_PRAGMA_UNROLL
|
| 633 |
+
for (short j = 0; j < TN; j++) {
|
| 634 |
+
// Get accumulated result and associated offset in C
|
| 635 |
+
thread auto& accum = Ctile.frag_at(i, j);
|
| 636 |
+
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
| 637 |
+
|
| 638 |
+
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
| 639 |
+
|
| 640 |
+
// Read C
|
| 641 |
+
U c_elems[kelems] = {0};
|
| 642 |
+
|
| 643 |
+
STEEL_PRAGMA_UNROLL
|
| 644 |
+
for (short k = 0; k < kelems; k++) {
|
| 645 |
+
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
| 646 |
+
c_elems[k] = C[offset_c + k * fdc];
|
| 647 |
+
}
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
// Apply epilogue
|
| 651 |
+
STEEL_PRAGMA_UNROLL
|
| 652 |
+
for (short k = 0; k < kelems; k++) {
|
| 653 |
+
accum[k] = epilogue_op.apply(accum[k], c_elems[k]);
|
| 654 |
+
}
|
| 655 |
+
}
|
| 656 |
+
}
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
/* Store results from simdgroup_matrix results into device memory */
|
| 660 |
+
METAL_FUNC void store_result(
|
| 661 |
+
device U* D,
|
| 662 |
+
const int ldd,
|
| 663 |
+
const device U* C,
|
| 664 |
+
const int ldc,
|
| 665 |
+
const int fdc,
|
| 666 |
+
thread const Epilogue& epilogue_op) const {
|
| 667 |
+
// Adjust for simdgroup and thread location
|
| 668 |
+
C += (sm)*ldc + (sn)*fdc;
|
| 669 |
+
D += (sm)*ldd + sn;
|
| 670 |
+
|
| 671 |
+
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
| 672 |
+
|
| 673 |
+
// Loop over all simdgroup tiles
|
| 674 |
+
STEEL_PRAGMA_UNROLL
|
| 675 |
+
for (short i = 0; i < TM; i++) {
|
| 676 |
+
STEEL_PRAGMA_UNROLL
|
| 677 |
+
for (short j = 0; j < TN; j++) {
|
| 678 |
+
// Get accumulated result and associated offset in C
|
| 679 |
+
thread const auto& accum = Ctile.frag_at(i, j);
|
| 680 |
+
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
| 681 |
+
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
| 682 |
+
|
| 683 |
+
// Apply epilogue
|
| 684 |
+
STEEL_PRAGMA_UNROLL
|
| 685 |
+
for (short k = 0; k < kelems; k++) {
|
| 686 |
+
D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
| 687 |
+
}
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
METAL_FUNC void store_result_safe(
|
| 693 |
+
device U* D,
|
| 694 |
+
const int ldd,
|
| 695 |
+
const device U* C,
|
| 696 |
+
const int ldc,
|
| 697 |
+
const int fdc,
|
| 698 |
+
short2 dst_tile_dims,
|
| 699 |
+
thread const Epilogue& epilogue_op) const {
|
| 700 |
+
// Adjust for simdgroup and thread location
|
| 701 |
+
C += (sm)*ldc + (sn)*fdc;
|
| 702 |
+
D += (sm)*ldd + sn;
|
| 703 |
+
dst_tile_dims -= short2(sn, sm);
|
| 704 |
+
|
| 705 |
+
if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
|
| 706 |
+
return;
|
| 707 |
+
|
| 708 |
+
constexpr short kelems = decltype(Ctile)::kElemsPerFrag;
|
| 709 |
+
|
| 710 |
+
STEEL_PRAGMA_UNROLL
|
| 711 |
+
for (int i = 0; i < TM; i++) {
|
| 712 |
+
if (i * TM_stride < dst_tile_dims.y) {
|
| 713 |
+
STEEL_PRAGMA_UNROLL
|
| 714 |
+
for (int j = 0; j < TN; j++) {
|
| 715 |
+
// Get accumulated result and associated offset in C
|
| 716 |
+
thread const auto& accum = Ctile.frag_at(i, j);
|
| 717 |
+
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
| 718 |
+
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
| 719 |
+
|
| 720 |
+
// Apply epilogue
|
| 721 |
+
STEEL_PRAGMA_UNROLL
|
| 722 |
+
for (short k = 0; k < kelems; k++) {
|
| 723 |
+
if ((j * TN_stride + k) < dst_tile_dims.x) {
|
| 724 |
+
D[offset_d + k] =
|
| 725 |
+
epilogue_op.apply(accum[k], C[offset_c + k * fdc]);
|
| 726 |
+
}
|
| 727 |
+
}
|
| 728 |
+
}
|
| 729 |
+
}
|
| 730 |
+
}
|
| 731 |
+
}
|
| 732 |
+
};
|
| 733 |
+
|
| 734 |
+
} // namespace steel
|
| 735 |
+
} // namespace mlx
|
bitsandbytes_mps/gemm/params.h
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 6 |
+
// GEMM param classes
|
| 7 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
|
| 9 |
+
namespace mlx {
|
| 10 |
+
namespace steel {
|
| 11 |
+
|
| 12 |
+
struct GEMMParams {
|
| 13 |
+
const int M;
|
| 14 |
+
const int N;
|
| 15 |
+
const int K;
|
| 16 |
+
|
| 17 |
+
const int lda;
|
| 18 |
+
const int ldb;
|
| 19 |
+
const int ldd;
|
| 20 |
+
|
| 21 |
+
const int tiles_n;
|
| 22 |
+
const int tiles_m;
|
| 23 |
+
|
| 24 |
+
const int64_t batch_stride_a;
|
| 25 |
+
const int64_t batch_stride_b;
|
| 26 |
+
const int64_t batch_stride_d;
|
| 27 |
+
|
| 28 |
+
const int swizzle_log;
|
| 29 |
+
const int gemm_k_iterations_aligned;
|
| 30 |
+
|
| 31 |
+
const int batch_ndim;
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
struct GEMMSpiltKParams {
|
| 35 |
+
const int M;
|
| 36 |
+
const int N;
|
| 37 |
+
const int K;
|
| 38 |
+
|
| 39 |
+
const int lda;
|
| 40 |
+
const int ldb;
|
| 41 |
+
const int ldc;
|
| 42 |
+
|
| 43 |
+
const int tiles_n;
|
| 44 |
+
const int tiles_m;
|
| 45 |
+
|
| 46 |
+
const int split_k_partitions;
|
| 47 |
+
const int split_k_partition_stride;
|
| 48 |
+
const int split_k_partition_size;
|
| 49 |
+
|
| 50 |
+
const int gemm_k_iterations_aligned;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
struct GEMMAddMMParams {
|
| 54 |
+
const int ldc;
|
| 55 |
+
const int fdc;
|
| 56 |
+
|
| 57 |
+
const int64_t batch_stride_c;
|
| 58 |
+
|
| 59 |
+
const float alpha;
|
| 60 |
+
const float beta;
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
} // namespace steel
|
| 64 |
+
} // namespace mlx
|
bitsandbytes_mps/gemm/transforms.h
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include "gemm/utils.h"
|
| 6 |
+
|
| 7 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
// Transforms and Epilogues
|
| 9 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 10 |
+
|
| 11 |
+
namespace mlx {
|
| 12 |
+
namespace steel {
|
| 13 |
+
|
| 14 |
+
template <typename OutT, typename InT>
|
| 15 |
+
struct TransformNone {
|
| 16 |
+
static METAL_FUNC OutT apply(InT x) {
|
| 17 |
+
return static_cast<OutT>(x);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
static METAL_FUNC OutT apply(InT x, OutT) {
|
| 21 |
+
return static_cast<OutT>(x);
|
| 22 |
+
}
|
| 23 |
+
};
|
| 24 |
+
|
| 25 |
+
template <typename OutT, typename InT>
|
| 26 |
+
struct TransformAdd {
|
| 27 |
+
TransformAdd(const float, const float) {}
|
| 28 |
+
|
| 29 |
+
static METAL_FUNC OutT apply(InT x) {
|
| 30 |
+
return static_cast<OutT>(x);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
static METAL_FUNC OutT apply(InT x, OutT c) {
|
| 34 |
+
return static_cast<OutT>(x) + c;
|
| 35 |
+
}
|
| 36 |
+
};
|
| 37 |
+
|
| 38 |
+
template <typename OutT, typename InT>
|
| 39 |
+
struct TransformAxpby {
|
| 40 |
+
const float alpha;
|
| 41 |
+
const float beta;
|
| 42 |
+
|
| 43 |
+
TransformAxpby(const float alpha_, const float beta_)
|
| 44 |
+
: alpha(alpha_), beta(beta_) {}
|
| 45 |
+
|
| 46 |
+
static METAL_FUNC OutT apply(InT x) {
|
| 47 |
+
return static_cast<OutT>(x);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
METAL_FUNC OutT apply(InT x, OutT c) const {
|
| 51 |
+
return static_cast<OutT>(
|
| 52 |
+
x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
|
| 53 |
+
}
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
template <typename T>
|
| 57 |
+
struct AccumHelper {
|
| 58 |
+
typedef float accum_type;
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
struct BlockSwizzle {
|
| 62 |
+
static METAL_FUNC int2
|
| 63 |
+
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
|
| 64 |
+
const int tid_x = (tid.x) >> swizzle_log;
|
| 65 |
+
const int tid_y =
|
| 66 |
+
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
| 67 |
+
return int2(tid_x, tid_y);
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
} // namespace steel
|
| 72 |
+
} // namespace mlx
|
bitsandbytes_mps/gemm/utils.h
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_stdlib>
|
| 6 |
+
|
| 7 |
+
METAL_FUNC ulong2 elem_to_loc_broadcast(
|
| 8 |
+
uint elem,
|
| 9 |
+
constant const int* shape,
|
| 10 |
+
constant const int64_t* a_strides,
|
| 11 |
+
constant const int64_t* b_strides,
|
| 12 |
+
int ndim) {
|
| 13 |
+
ulong loc_a{0};
|
| 14 |
+
ulong loc_b{0};
|
| 15 |
+
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
| 16 |
+
int pos_in_dim = (elem % shape[i]);
|
| 17 |
+
elem /= shape[i];
|
| 18 |
+
loc_a += pos_in_dim * a_strides[i];
|
| 19 |
+
loc_b += pos_in_dim * b_strides[i];
|
| 20 |
+
}
|
| 21 |
+
return ulong2(loc_a, loc_b);
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
METAL_FUNC ulong3 elem_to_loc_broadcast(
|
| 25 |
+
uint elem,
|
| 26 |
+
constant const int* shape,
|
| 27 |
+
constant const int64_t* a_strides,
|
| 28 |
+
constant const int64_t* b_strides,
|
| 29 |
+
constant const int64_t* c_strides,
|
| 30 |
+
int ndim) {
|
| 31 |
+
ulong loc_a{0};
|
| 32 |
+
ulong loc_b{0};
|
| 33 |
+
ulong loc_c{0};
|
| 34 |
+
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
| 35 |
+
int pos_in_dim = (elem % shape[i]);
|
| 36 |
+
elem /= shape[i];
|
| 37 |
+
loc_a += pos_in_dim * a_strides[i];
|
| 38 |
+
loc_b += pos_in_dim * b_strides[i];
|
| 39 |
+
loc_c += pos_in_dim * c_strides[i];
|
| 40 |
+
}
|
| 41 |
+
return ulong3(loc_a, loc_b, loc_c);
|
| 42 |
+
}
|
bitsandbytes_mps/gemm/utils/integral_constant.h
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_stdlib>
|
| 6 |
+
#include "gemm/utils/type_traits.h"
|
| 7 |
+
|
| 8 |
+
#pragma METAL internals : enable
|
| 9 |
+
|
| 10 |
+
namespace mlx {
|
| 11 |
+
namespace steel {
|
| 12 |
+
|
| 13 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 14 |
+
// Integral constant with casting
|
| 15 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 16 |
+
|
| 17 |
+
template <typename T, T v>
|
| 18 |
+
struct integral_constant {
|
| 19 |
+
static constexpr constant T value = v;
|
| 20 |
+
using value_type = T;
|
| 21 |
+
using type = integral_constant;
|
| 22 |
+
|
| 23 |
+
METAL_FUNC constexpr operator value_type() const noexcept {
|
| 24 |
+
return value;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
// METAL_FUNC constexpr value_type operator()() const noexcept {
|
| 28 |
+
// return value;
|
| 29 |
+
// }
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
template <bool B>
|
| 33 |
+
using bool_constant = integral_constant<bool, B>;
|
| 34 |
+
using true_type = bool_constant<true>;
|
| 35 |
+
using false_type = bool_constant<false>;
|
| 36 |
+
|
| 37 |
+
template <class T>
|
| 38 |
+
struct is_integral : bool_constant<metal::is_integral<T>::value> {};
|
| 39 |
+
|
| 40 |
+
template <class T, T v>
|
| 41 |
+
struct is_integral<integral_constant<T, v>>
|
| 42 |
+
: bool_constant<metal::is_integral<T>::value> {};
|
| 43 |
+
|
| 44 |
+
template <typename T>
|
| 45 |
+
constexpr constant bool is_integral_v = is_integral<T>::value;
|
| 46 |
+
|
| 47 |
+
template <int val>
|
| 48 |
+
using Int = integral_constant<int, val>;
|
| 49 |
+
|
| 50 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
// Binary Operators on Integral constants
|
| 52 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
#define integral_const_binop(__op__, __operator__) \
|
| 55 |
+
template <typename T, T tv, typename U, U uv> \
|
| 56 |
+
METAL_FUNC constexpr auto __operator__( \
|
| 57 |
+
integral_constant<T, tv>, integral_constant<U, uv>) { \
|
| 58 |
+
constexpr auto res = tv __op__ uv; \
|
| 59 |
+
return integral_constant<decltype(res), res>{}; \
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
integral_const_binop(+, operator+);
|
| 63 |
+
integral_const_binop(-, operator-);
|
| 64 |
+
integral_const_binop(*, operator*);
|
| 65 |
+
integral_const_binop(/, operator/);
|
| 66 |
+
|
| 67 |
+
integral_const_binop(==, operator==);
|
| 68 |
+
integral_const_binop(!=, operator!=);
|
| 69 |
+
integral_const_binop(<, operator<);
|
| 70 |
+
integral_const_binop(>, operator>);
|
| 71 |
+
integral_const_binop(<=, operator<=);
|
| 72 |
+
integral_const_binop(>=, operator>=);
|
| 73 |
+
|
| 74 |
+
integral_const_binop(&&, operator&&);
|
| 75 |
+
integral_const_binop(||, operator||);
|
| 76 |
+
|
| 77 |
+
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
|
| 78 |
+
METAL_FUNC constexpr auto operator||(true_type, T) {
|
| 79 |
+
return true_type{};
|
| 80 |
+
}
|
| 81 |
+
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
|
| 82 |
+
METAL_FUNC constexpr auto operator||(T, true_type) {
|
| 83 |
+
return true_type{};
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
|
| 87 |
+
METAL_FUNC constexpr auto operator&&(false_type, T) {
|
| 88 |
+
return false_type{};
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
|
| 92 |
+
METAL_FUNC constexpr auto operator&&(T, false_type) {
|
| 93 |
+
return false_type{};
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
// Dispatch utilities
|
| 97 |
+
template <typename F>
|
| 98 |
+
void dispatch_bool(bool v, F f) {
|
| 99 |
+
if (v) {
|
| 100 |
+
f(true_type{});
|
| 101 |
+
} else {
|
| 102 |
+
f(false_type{});
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
template <int start, int stop, int step, typename F>
|
| 107 |
+
constexpr void const_for_loop(F f) {
|
| 108 |
+
if constexpr (start < stop) {
|
| 109 |
+
constexpr auto idx = Int<start>{};
|
| 110 |
+
f(idx);
|
| 111 |
+
const_for_loop<start + step, stop, step, F>(f);
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
#undef integral_const_binop
|
| 116 |
+
|
| 117 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 118 |
+
// Reduction operators
|
| 119 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 120 |
+
|
| 121 |
+
template <typename T>
|
| 122 |
+
METAL_FUNC constexpr T sum(T x) {
|
| 123 |
+
return x;
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
template <typename T, typename... Us>
|
| 127 |
+
METAL_FUNC constexpr auto sum(T x, Us... us) {
|
| 128 |
+
return x + sum(us...);
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
} // namespace steel
|
| 132 |
+
} // namespace mlx
|
| 133 |
+
|
| 134 |
+
#pragma METAL internals : disable
|
bitsandbytes_mps/gemm/utils/type_traits.h
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_stdlib>
|
| 6 |
+
|
| 7 |
+
#pragma METAL internals : enable
|
| 8 |
+
|
| 9 |
+
namespace metal {
|
| 10 |
+
|
| 11 |
+
template <typename T>
|
| 12 |
+
struct is_empty : metal::bool_constant<__is_empty(T)> {};
|
| 13 |
+
|
| 14 |
+
#ifdef __cpp_variable_templates
|
| 15 |
+
template <typename T>
|
| 16 |
+
constexpr constant bool is_empty_v = is_empty<T>::value;
|
| 17 |
+
#endif
|
| 18 |
+
|
| 19 |
+
template <typename... Ts>
|
| 20 |
+
struct make_void {
|
| 21 |
+
typedef void type;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
template <typename... Ts>
|
| 25 |
+
using void_t = typename make_void<Ts...>::type;
|
| 26 |
+
|
| 27 |
+
template <class T>
|
| 28 |
+
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};
|
| 29 |
+
|
| 30 |
+
template <typename T>
|
| 31 |
+
struct pointer_element {};
|
| 32 |
+
|
| 33 |
+
template <typename T>
|
| 34 |
+
struct pointer_element<thread T*> {
|
| 35 |
+
using type = remove_cv_t<T>;
|
| 36 |
+
};
|
| 37 |
+
template <typename T>
|
| 38 |
+
struct pointer_element<device T*> {
|
| 39 |
+
using type = remove_cv_t<T>;
|
| 40 |
+
};
|
| 41 |
+
template <typename T>
|
| 42 |
+
struct pointer_element<constant T*> {
|
| 43 |
+
using type = remove_cv_t<T>;
|
| 44 |
+
};
|
| 45 |
+
template <typename T>
|
| 46 |
+
struct pointer_element<threadgroup T*> {
|
| 47 |
+
using type = remove_cv_t<T>;
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
+
template <typename T>
|
| 51 |
+
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;
|
| 52 |
+
|
| 53 |
+
} // namespace metal
|
| 54 |
+
|
| 55 |
+
#pragma METAL internals : disable
|
bitsandbytes_mps/quantized_utils.h
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023-2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#include <metal_simdgroup>
|
| 4 |
+
#include <metal_stdlib>
|
| 5 |
+
|
| 6 |
+
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
| 7 |
+
METAL_FUNC void gemm_loop_aligned(
|
| 8 |
+
threadgroup T* As,
|
| 9 |
+
threadgroup T* Bs,
|
| 10 |
+
thread mma_t& mma_op,
|
| 11 |
+
thread loader_a_t& loader_a,
|
| 12 |
+
thread loader_b_t& loader_b,
|
| 13 |
+
const int k_iterations) {
|
| 14 |
+
for (int k = 0; k < k_iterations; k++) {
|
| 15 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 16 |
+
|
| 17 |
+
// Load elements into threadgroup memory
|
| 18 |
+
loader_a.load_unsafe();
|
| 19 |
+
loader_b.load_unsafe();
|
| 20 |
+
|
| 21 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 22 |
+
|
| 23 |
+
// Multiply and accumulate threadgroup elements
|
| 24 |
+
mma_op.mma(As, Bs);
|
| 25 |
+
|
| 26 |
+
// Prepare for next iteration
|
| 27 |
+
loader_a.next();
|
| 28 |
+
loader_b.next();
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
template <
|
| 33 |
+
bool rows_aligned,
|
| 34 |
+
bool cols_aligned,
|
| 35 |
+
bool transpose,
|
| 36 |
+
typename T,
|
| 37 |
+
typename mma_t,
|
| 38 |
+
typename loader_a_t,
|
| 39 |
+
typename loader_b_t>
|
| 40 |
+
METAL_FUNC void gemm_loop_unaligned(
|
| 41 |
+
threadgroup T* As,
|
| 42 |
+
threadgroup T* Bs,
|
| 43 |
+
thread mma_t& mma_op,
|
| 44 |
+
thread loader_a_t& loader_a,
|
| 45 |
+
thread loader_b_t& loader_b,
|
| 46 |
+
const int k_iterations,
|
| 47 |
+
const short tgp_bm,
|
| 48 |
+
const short tgp_bn,
|
| 49 |
+
const short tgp_bk) {
|
| 50 |
+
for (int k = 0; k < k_iterations; k++) {
|
| 51 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 52 |
+
|
| 53 |
+
// Load elements into threadgroup memory
|
| 54 |
+
if (rows_aligned) {
|
| 55 |
+
loader_a.load_unsafe();
|
| 56 |
+
} else {
|
| 57 |
+
loader_a.load_safe(short2(tgp_bk, tgp_bm));
|
| 58 |
+
}
|
| 59 |
+
if (cols_aligned) {
|
| 60 |
+
loader_b.load_unsafe();
|
| 61 |
+
} else {
|
| 62 |
+
loader_b.load_safe(
|
| 63 |
+
transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 67 |
+
|
| 68 |
+
// Multiply and accumulate threadgroup elements
|
| 69 |
+
mma_op.mma(As, Bs);
|
| 70 |
+
|
| 71 |
+
// Prepare for next iteration
|
| 72 |
+
loader_a.next();
|
| 73 |
+
loader_b.next();
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
|
| 78 |
+
METAL_FUNC void gemm_loop_finalize(
|
| 79 |
+
threadgroup T* As,
|
| 80 |
+
threadgroup T* Bs,
|
| 81 |
+
thread mma_t& mma_op,
|
| 82 |
+
thread loader_a_t& loader_a,
|
| 83 |
+
thread loader_b_t& loader_b,
|
| 84 |
+
const short2 tile_a,
|
| 85 |
+
const short2 tile_b) {
|
| 86 |
+
loader_a.load_safe(tile_a);
|
| 87 |
+
loader_b.load_safe(tile_b);
|
| 88 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 89 |
+
mma_op.mma(As, Bs);
|
| 90 |
+
}
|
bitsandbytes_mps/utils.h
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright © 2023-2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <metal_math>
|
| 6 |
+
|
| 7 |
+
#include "bf16.h"
|
| 8 |
+
#include "defines.h"
|
| 9 |
+
|
| 10 |
+
typedef half float16_t;
|
| 11 |
+
|
| 12 |
+
// Work per thread values for different types. The values here are expected to
|
| 13 |
+
// match get_work_per_thread in mlx/backend/metal/utils.h
|
| 14 |
+
template <typename U>
|
| 15 |
+
struct WorkPerThread {
|
| 16 |
+
static_assert(sizeof(U) <= 8, "Type too large");
|
| 17 |
+
static constexpr int constant n = 8 / sizeof(U);
|
| 18 |
+
};
|
| 19 |
+
|
| 20 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 21 |
+
// Type limits utils
|
| 22 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 23 |
+
|
| 24 |
+
template <typename U>
|
| 25 |
+
struct Limits {
|
| 26 |
+
static const constant U max = metal::numeric_limits<U>::max();
|
| 27 |
+
static const constant U min = metal::numeric_limits<U>::min();
|
| 28 |
+
static const constant U finite_max = metal::numeric_limits<U>::max();
|
| 29 |
+
static const constant U finite_min = metal::numeric_limits<U>::min();
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
#define instantiate_default_limit(type) \
|
| 33 |
+
template <> \
|
| 34 |
+
struct Limits<type> { \
|
| 35 |
+
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
| 36 |
+
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
| 37 |
+
static constexpr constant type finite_max = \
|
| 38 |
+
metal::numeric_limits<type>::max(); \
|
| 39 |
+
static constexpr constant type finite_min = \
|
| 40 |
+
metal::numeric_limits<type>::min(); \
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
instantiate_default_limit(uint8_t);
|
| 44 |
+
instantiate_default_limit(uint16_t);
|
| 45 |
+
instantiate_default_limit(uint32_t);
|
| 46 |
+
instantiate_default_limit(uint64_t);
|
| 47 |
+
instantiate_default_limit(int8_t);
|
| 48 |
+
instantiate_default_limit(int16_t);
|
| 49 |
+
instantiate_default_limit(int32_t);
|
| 50 |
+
instantiate_default_limit(int64_t);
|
| 51 |
+
|
| 52 |
+
#define instantiate_float_limit(type) \
|
| 53 |
+
template <> \
|
| 54 |
+
struct Limits<type> { \
|
| 55 |
+
static constexpr constant type max = \
|
| 56 |
+
metal::numeric_limits<type>::infinity(); \
|
| 57 |
+
static constexpr constant type min = \
|
| 58 |
+
-metal::numeric_limits<type>::infinity(); \
|
| 59 |
+
static constexpr constant type finite_max = \
|
| 60 |
+
metal::numeric_limits<type>::max(); \
|
| 61 |
+
static constexpr constant type finite_min = \
|
| 62 |
+
-metal::numeric_limits<type>::max(); \
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
instantiate_float_limit(half);
|
| 66 |
+
instantiate_float_limit(float);
|
| 67 |
+
instantiate_float_limit(bfloat16_t);
|
| 68 |
+
|
| 69 |
+
template <>
|
| 70 |
+
struct Limits<bool> {
|
| 71 |
+
static constexpr constant bool max = true;
|
| 72 |
+
static constexpr constant bool min = false;
|
| 73 |
+
};
|
| 74 |
+
|
| 75 |
+
// complex64_t specialization removed - not needed for BnB kernels
|
| 76 |
+
|
| 77 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 78 |
+
// Indexing utils
|
| 79 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 80 |
+
|
| 81 |
+
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
| 82 |
+
|
| 83 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 84 |
+
// Single Array with generic dims
|
| 85 |
+
|
| 86 |
+
template <typename IdxT = int64_t>
|
| 87 |
+
METAL_FUNC IdxT elem_to_loc(
|
| 88 |
+
IdxT elem,
|
| 89 |
+
constant const int* shape,
|
| 90 |
+
constant const int64_t* strides,
|
| 91 |
+
int ndim) {
|
| 92 |
+
IdxT loc = 0;
|
| 93 |
+
for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
|
| 94 |
+
loc += (elem % shape[i]) * IdxT(strides[i]);
|
| 95 |
+
elem /= shape[i];
|
| 96 |
+
}
|
| 97 |
+
return loc;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// Non templated version to handle arbitrary dims
|
| 101 |
+
template <typename IdxT = int64_t>
|
| 102 |
+
METAL_FUNC IdxT elem_to_loc(
|
| 103 |
+
uint3 elem,
|
| 104 |
+
constant const int* shape,
|
| 105 |
+
constant const int64_t* strides,
|
| 106 |
+
int ndim) {
|
| 107 |
+
IdxT loc =
|
| 108 |
+
elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
|
| 109 |
+
for (int d = ndim - 3; d >= 0; --d) {
|
| 110 |
+
loc += (elem.z % shape[d]) * IdxT(strides[d]);
|
| 111 |
+
elem.z /= shape[d];
|
| 112 |
+
}
|
| 113 |
+
return loc;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 117 |
+
// Single Array with fixed N dims
|
| 118 |
+
|
| 119 |
+
template <typename IdxT = int64_t>
|
| 120 |
+
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
|
| 121 |
+
return elem * IdxT(stride);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
template <typename IdxT = int64_t>
|
| 125 |
+
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
|
| 126 |
+
return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template <typename IdxT = int64_t>
|
| 130 |
+
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
|
| 131 |
+
return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
|
| 132 |
+
elem.z * IdxT(strides[0]);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 136 |
+
// Multiple Arrays with generic dims
|
| 137 |
+
|
| 138 |
+
template <typename IdxT = int64_t>
|
| 139 |
+
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
|
| 140 |
+
uint3 elem,
|
| 141 |
+
constant const int* shape,
|
| 142 |
+
constant const int64_t* a_strides,
|
| 143 |
+
constant const int64_t* b_strides,
|
| 144 |
+
int ndim) {
|
| 145 |
+
vec<IdxT, 2> loc = {
|
| 146 |
+
IdxT(
|
| 147 |
+
elem.x * IdxT(a_strides[ndim - 1]) +
|
| 148 |
+
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
|
| 149 |
+
IdxT(
|
| 150 |
+
elem.x * IdxT(b_strides[ndim - 1]) +
|
| 151 |
+
elem.y * IdxT(b_strides[ndim - 2]))};
|
| 152 |
+
for (int d = ndim - 3; d >= 0; --d) {
|
| 153 |
+
uint l = elem.z % shape[d];
|
| 154 |
+
loc.x += l * IdxT(a_strides[d]);
|
| 155 |
+
loc.y += l * IdxT(b_strides[d]);
|
| 156 |
+
elem.z /= shape[d];
|
| 157 |
+
}
|
| 158 |
+
return loc;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
template <typename IdxT = int64_t>
|
| 162 |
+
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
|
| 163 |
+
uint3 elem,
|
| 164 |
+
constant const int* shape,
|
| 165 |
+
constant const int64_t* a_strides,
|
| 166 |
+
constant const int64_t* b_strides,
|
| 167 |
+
constant const int64_t* c_strides,
|
| 168 |
+
int ndim) {
|
| 169 |
+
vec<IdxT, 3> loc = {
|
| 170 |
+
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
|
| 171 |
+
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
|
| 172 |
+
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
|
| 173 |
+
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
|
| 174 |
+
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
|
| 175 |
+
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
|
| 176 |
+
for (int d = ndim - 3; d >= 0; --d) {
|
| 177 |
+
uint l = elem.z % shape[d];
|
| 178 |
+
loc.x += l * IdxT(a_strides[d]);
|
| 179 |
+
loc.y += l * IdxT(b_strides[d]);
|
| 180 |
+
loc.z += l * IdxT(c_strides[d]);
|
| 181 |
+
elem.z /= shape[d];
|
| 182 |
+
}
|
| 183 |
+
return loc;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 187 |
+
// Elem to loc in a loop utils
|
| 188 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 189 |
+
|
| 190 |
+
template <int DIM, typename OffsetT = size_t, bool General = true>
|
| 191 |
+
struct LoopedElemToLoc {
|
| 192 |
+
int dim;
|
| 193 |
+
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
|
| 194 |
+
OffsetT offset{0};
|
| 195 |
+
int index{0};
|
| 196 |
+
|
| 197 |
+
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
|
| 198 |
+
|
| 199 |
+
void next(const constant int* shape, const constant int64_t* strides) {
|
| 200 |
+
if (dim == 0) {
|
| 201 |
+
return;
|
| 202 |
+
}
|
| 203 |
+
index++;
|
| 204 |
+
offset += OffsetT(strides[dim - 1]);
|
| 205 |
+
if (index >= shape[dim - 1]) {
|
| 206 |
+
index = 0;
|
| 207 |
+
inner_looper.next(shape, strides);
|
| 208 |
+
offset = inner_looper.offset;
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
void next(int n, const constant int* shape, const constant int64_t* strides) {
|
| 213 |
+
if (dim == 0) {
|
| 214 |
+
return;
|
| 215 |
+
}
|
| 216 |
+
index += n;
|
| 217 |
+
offset += n * OffsetT(strides[dim - 1]);
|
| 218 |
+
|
| 219 |
+
if (index >= shape[dim - 1]) {
|
| 220 |
+
int extra = index - shape[dim - 1];
|
| 221 |
+
if (extra >= shape[dim - 1]) {
|
| 222 |
+
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
|
| 223 |
+
extra = extra % shape[dim - 1];
|
| 224 |
+
} else {
|
| 225 |
+
inner_looper.next(shape, strides);
|
| 226 |
+
}
|
| 227 |
+
index = 0;
|
| 228 |
+
offset = inner_looper.offset;
|
| 229 |
+
if (extra > 0) {
|
| 230 |
+
next(extra, shape, strides);
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
OffsetT location() {
|
| 236 |
+
return offset;
|
| 237 |
+
}
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
template <typename OffsetT>
|
| 241 |
+
struct LoopedElemToLoc<1, OffsetT, true> {
|
| 242 |
+
int dim;
|
| 243 |
+
OffsetT offset{0};
|
| 244 |
+
uint index{0};
|
| 245 |
+
|
| 246 |
+
LoopedElemToLoc(int dim) : dim(dim) {}
|
| 247 |
+
|
| 248 |
+
void next(const constant int* shape, const constant int64_t* strides) {
|
| 249 |
+
index++;
|
| 250 |
+
if (dim > 1) {
|
| 251 |
+
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
| 252 |
+
} else {
|
| 253 |
+
offset += OffsetT(strides[0]);
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
void next(int n, const constant int* shape, const constant int64_t* strides) {
|
| 258 |
+
index += n;
|
| 259 |
+
if (dim > 1) {
|
| 260 |
+
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
|
| 261 |
+
} else {
|
| 262 |
+
offset = index * OffsetT(strides[0]);
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
OffsetT location() {
|
| 267 |
+
return offset;
|
| 268 |
+
}
|
| 269 |
+
};
|
| 270 |
+
|
| 271 |
+
template <typename OffsetT>
|
| 272 |
+
struct LoopedElemToLoc<1, OffsetT, false> {
|
| 273 |
+
OffsetT offset{0};
|
| 274 |
+
|
| 275 |
+
LoopedElemToLoc(int) {}
|
| 276 |
+
|
| 277 |
+
void next(const constant int*, const constant int64_t* strides) {
|
| 278 |
+
offset += OffsetT(strides[0]);
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
void next(int n, const constant int*, const constant int64_t* strides) {
|
| 282 |
+
offset += n * OffsetT(strides[0]);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
OffsetT location() {
|
| 286 |
+
return offset;
|
| 287 |
+
}
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 291 |
+
// Calculation utils
|
| 292 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 293 |
+
|
| 294 |
+
/** Compute ceil((float)N/(float)M) */
|
| 295 |
+
template <typename T, typename U>
|
| 296 |
+
inline T ceildiv(T N, U M) {
|
| 297 |
+
return (N + M - 1) / M;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
|
| 301 |
+
inline float log1p(float x) {
|
| 302 |
+
float xp1 = 1.0f + x;
|
| 303 |
+
if (xp1 == Limits<float>::max) {
|
| 304 |
+
return Limits<float>::max;
|
| 305 |
+
}
|
| 306 |
+
if (xp1 == 1.0f) {
|
| 307 |
+
return x;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
inline bfloat16_t log1p(bfloat16_t x) {
|
| 314 |
+
float xp1 = 1.0f + static_cast<float>(x);
|
| 315 |
+
if (xp1 == Limits<float>::max) {
|
| 316 |
+
return Limits<bfloat16_t>::max;
|
| 317 |
+
}
|
| 318 |
+
if (xp1 == 1.0f) {
|
| 319 |
+
return x;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 326 |
+
// SIMD shuffle ops
|
| 327 |
+
///////////////////////////////////////////////////////////////////////////////
|
| 328 |
+
|
| 329 |
+
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
| 330 |
+
return as_type<uint64_t>(
|
| 331 |
+
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
| 335 |
+
return as_type<int64_t>(
|
| 336 |
+
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
| 340 |
+
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
|
| 344 |
+
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
|
| 348 |
+
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
inline bool simd_shuffle_up(bool data, uint16_t delta) {
|
| 352 |
+
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
inline uint64_t
|
| 356 |
+
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
|
| 357 |
+
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
|
| 358 |
+
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
inline int64_t
|
| 362 |
+
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
|
| 363 |
+
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
|
| 364 |
+
as_type<uint2>(data), as_type<uint2>(filling), delta));
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
|
| 368 |
+
return simd_shuffle_and_fill_up(
|
| 369 |
+
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
|
| 373 |
+
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
|
| 377 |
+
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
inline bool simd_shuffle(bool data, uint16_t lane) {
|
| 381 |
+
return simd_shuffle(static_cast<uint32_t>(data), lane);
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
// std::conditional is not included with Metal
|
| 385 |
+
template <bool condition, typename T, typename U>
|
| 386 |
+
struct ConditionalType {
|
| 387 |
+
using type = U;
|
| 388 |
+
};
|
| 389 |
+
|
| 390 |
+
template <typename T, typename U>
|
| 391 |
+
struct ConditionalType<true, T, U> {
|
| 392 |
+
using type = T;
|
| 393 |
+
};
|
build.toml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "bitsandbytes_mps"
|
| 3 |
+
backends = ["metal"]
|
| 4 |
+
|
| 5 |
+
[torch]
|
| 6 |
+
minver = "2.9"
|
| 7 |
+
src = [
|
| 8 |
+
"torch-ext/torch_binding.cpp",
|
| 9 |
+
"torch-ext/torch_binding.h",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
[general.hub]
|
| 13 |
+
repo-id = "kernels-community/bitsandbytes-mps"
|
| 14 |
+
|
| 15 |
+
[kernel.bitsandbytes_mps]
|
| 16 |
+
|
| 17 |
+
depends = ["torch"]
|
| 18 |
+
backend = "metal"
|
| 19 |
+
|
| 20 |
+
src = [
|
| 21 |
+
# Utility headers (from MLX)
|
| 22 |
+
"bitsandbytes_mps/bf16.h",
|
| 23 |
+
"bitsandbytes_mps/bf16_math.h",
|
| 24 |
+
"bitsandbytes_mps/complex.h",
|
| 25 |
+
"bitsandbytes_mps/defines.h",
|
| 26 |
+
"bitsandbytes_mps/utils.h",
|
| 27 |
+
|
| 28 |
+
# GEMM infrastructure (from MLX steel)
|
| 29 |
+
"bitsandbytes_mps/gemm/defines.h",
|
| 30 |
+
"bitsandbytes_mps/gemm/gemm.h",
|
| 31 |
+
"bitsandbytes_mps/gemm/loader.h",
|
| 32 |
+
"bitsandbytes_mps/gemm/mma.h",
|
| 33 |
+
"bitsandbytes_mps/gemm/params.h",
|
| 34 |
+
"bitsandbytes_mps/gemm/transforms.h",
|
| 35 |
+
"bitsandbytes_mps/gemm/utils.h",
|
| 36 |
+
"bitsandbytes_mps/gemm/utils/integral_constant.h",
|
| 37 |
+
"bitsandbytes_mps/gemm/utils/type_traits.h",
|
| 38 |
+
|
| 39 |
+
# Quantized matmul utilities (from MLX)
|
| 40 |
+
"bitsandbytes_mps/quantized_utils.h",
|
| 41 |
+
|
| 42 |
+
# BnB-specific: codebook types, kernel logic, Metal shaders, dispatch
|
| 43 |
+
"bitsandbytes_mps/bnb_types.h",
|
| 44 |
+
"bitsandbytes_mps/bnb_quantized.h",
|
| 45 |
+
"bitsandbytes_mps/bnb_quantized.metal",
|
| 46 |
+
"bitsandbytes_mps/bnb_quantized.mm",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
include = ["bitsandbytes_mps"]
|
build/torch210-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c9be20185521046ee042d66544cf94fa448c0e1c0455217ec81cef718d264ed9
|
| 3 |
+
size 845176
|
build/torch210-metal-aarch64-darwin/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _bitsandbytes_mps_9811962_dirty
|
| 3 |
+
ops = torch.ops._bitsandbytes_mps_9811962_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_bitsandbytes_mps_9811962_dirty::{op_name}"
|
build/torch29-metal-aarch64-darwin/_bitsandbytes_mps_9811962_dirty.abi3.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:be7a2bbf3cae711200855b297de2f3ba3d47379bf2ce52c61dd6cc3053075047
|
| 3 |
+
size 844504
|
build/torch29-metal-aarch64-darwin/_ops.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
-
from . import
|
| 3 |
-
ops = torch.ops.
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
-
return f"
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from . import _bitsandbytes_mps_9811962_dirty
|
| 3 |
+
ops = torch.ops._bitsandbytes_mps_9811962_dirty
|
| 4 |
|
| 5 |
def add_op_namespace_prefix(op_name: str):
|
| 6 |
"""
|
| 7 |
Prefix op by namespace.
|
| 8 |
"""
|
| 9 |
+
return f"_bitsandbytes_mps_9811962_dirty::{op_name}"
|
flake.lock
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1765121682,
|
| 6 |
+
"narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-utils": {
|
| 19 |
+
"inputs": {
|
| 20 |
+
"systems": "systems"
|
| 21 |
+
},
|
| 22 |
+
"locked": {
|
| 23 |
+
"lastModified": 1731533236,
|
| 24 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 25 |
+
"owner": "numtide",
|
| 26 |
+
"repo": "flake-utils",
|
| 27 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 28 |
+
"type": "github"
|
| 29 |
+
},
|
| 30 |
+
"original": {
|
| 31 |
+
"owner": "numtide",
|
| 32 |
+
"repo": "flake-utils",
|
| 33 |
+
"type": "github"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"kernel-builder": {
|
| 37 |
+
"inputs": {
|
| 38 |
+
"flake-compat": "flake-compat",
|
| 39 |
+
"flake-utils": "flake-utils",
|
| 40 |
+
"nixpkgs": "nixpkgs"
|
| 41 |
+
},
|
| 42 |
+
"locked": {
|
| 43 |
+
"lastModified": 1769448133,
|
| 44 |
+
"narHash": "sha256-XOp8+8u7fmXn1f63mJ40dPj/OHPMKtL9o4q7y0CUZFU=",
|
| 45 |
+
"owner": "huggingface",
|
| 46 |
+
"repo": "kernel-builder",
|
| 47 |
+
"rev": "078351df6e0fddb4a1a41ba3ffb8b804f58c4c6a",
|
| 48 |
+
"type": "github"
|
| 49 |
+
},
|
| 50 |
+
"original": {
|
| 51 |
+
"owner": "huggingface",
|
| 52 |
+
"repo": "kernel-builder",
|
| 53 |
+
"type": "github"
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"nixpkgs": {
|
| 57 |
+
"locked": {
|
| 58 |
+
"lastModified": 1766341660,
|
| 59 |
+
"narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
|
| 60 |
+
"owner": "NixOS",
|
| 61 |
+
"repo": "nixpkgs",
|
| 62 |
+
"rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
|
| 63 |
+
"type": "github"
|
| 64 |
+
},
|
| 65 |
+
"original": {
|
| 66 |
+
"owner": "NixOS",
|
| 67 |
+
"ref": "nixos-unstable-small",
|
| 68 |
+
"repo": "nixpkgs",
|
| 69 |
+
"type": "github"
|
| 70 |
+
}
|
| 71 |
+
},
|
| 72 |
+
"root": {
|
| 73 |
+
"inputs": {
|
| 74 |
+
"kernel-builder": "kernel-builder"
|
| 75 |
+
}
|
| 76 |
+
},
|
| 77 |
+
"systems": {
|
| 78 |
+
"locked": {
|
| 79 |
+
"lastModified": 1681028828,
|
| 80 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 81 |
+
"owner": "nix-systems",
|
| 82 |
+
"repo": "default",
|
| 83 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 84 |
+
"type": "github"
|
| 85 |
+
},
|
| 86 |
+
"original": {
|
| 87 |
+
"owner": "nix-systems",
|
| 88 |
+
"repo": "default",
|
| 89 |
+
"type": "github"
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"root": "root",
|
| 94 |
+
"version": 7
|
| 95 |
+
}
|
flake.nix
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
description = "Flake for triton-kernels kernels";
|
| 3 |
+
|
| 4 |
+
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
+
};
|
| 7 |
+
|
| 8 |
+
outputs =
|
| 9 |
+
{
|
| 10 |
+
self,
|
| 11 |
+
kernel-builder,
|
| 12 |
+
}:
|
| 13 |
+
kernel-builder.lib.genFlakeOutputs {
|
| 14 |
+
path = ./.;
|
| 15 |
+
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
| 16 |
+
};
|
| 17 |
+
}
|
tests/__pycache__/test_bnb_mps.cpython-312-pytest-8.4.2.pyc
ADDED
|
Binary file (18.1 kB). View file
|
|
|
tests/test_bnb_mps.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for bitsandbytes MPS 4-bit quantization kernels."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from bitsandbytes_mps import (
|
| 7 |
+
FP4,
|
| 8 |
+
NF4,
|
| 9 |
+
dequantize_4bit,
|
| 10 |
+
gemm_4bit,
|
| 11 |
+
gemv_4bit,
|
| 12 |
+
linear_4bit,
|
| 13 |
+
quantize_4bit,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
# NF4 codebook values (matching bnb_types.h)
|
| 17 |
+
NF4_CODEBOOK = [
|
| 18 |
+
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
|
| 19 |
+
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
|
| 20 |
+
0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
|
| 21 |
+
0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
|
| 22 |
+
0.7229568362236023, 1.0,
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
FP4_CODEBOOK = [
|
| 26 |
+
0.0, 0.005208333333, 0.66666667, 1.0, 0.33333333, 0.5, 0.16666667, 0.25,
|
| 27 |
+
0.0, -0.005208333333, -0.66666667, -1.0, -0.33333333, -0.5, -0.16666667,
|
| 28 |
+
-0.25,
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
DEVICE = "mps"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _reference_quantize_nf4(x_flat, blocksize):
|
| 35 |
+
"""Reference Python implementation of NF4 blockwise quantization."""
|
| 36 |
+
n = x_flat.numel()
|
| 37 |
+
num_blocks = (n + blocksize - 1) // blocksize
|
| 38 |
+
absmax = torch.zeros(num_blocks, dtype=torch.float32)
|
| 39 |
+
packed = torch.zeros((n + 1) // 2, dtype=torch.uint8)
|
| 40 |
+
|
| 41 |
+
codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)
|
| 42 |
+
|
| 43 |
+
for b in range(num_blocks):
|
| 44 |
+
start = b * blocksize
|
| 45 |
+
end = min(start + blocksize, n)
|
| 46 |
+
block = x_flat[start:end].float()
|
| 47 |
+
am = block.abs().max().item()
|
| 48 |
+
absmax[b] = am
|
| 49 |
+
|
| 50 |
+
if am > 0:
|
| 51 |
+
normalized = (block / am).clamp(-1, 1)
|
| 52 |
+
else:
|
| 53 |
+
normalized = torch.zeros_like(block)
|
| 54 |
+
|
| 55 |
+
for i in range(0, end - start, 2):
|
| 56 |
+
v0 = normalized[i].item()
|
| 57 |
+
q0 = (codebook - v0).abs().argmin().item()
|
| 58 |
+
|
| 59 |
+
q1 = 0
|
| 60 |
+
if i + 1 < end - start:
|
| 61 |
+
v1 = normalized[i + 1].item()
|
| 62 |
+
q1 = (codebook - v1).abs().argmin().item()
|
| 63 |
+
|
| 64 |
+
byte_idx = (start + i) // 2
|
| 65 |
+
packed[byte_idx] = (q0 << 4) | (q1 & 0x0F)
|
| 66 |
+
|
| 67 |
+
return packed, absmax
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _reference_dequantize_nf4(packed, absmax, blocksize, numel):
|
| 71 |
+
"""Reference Python implementation of NF4 blockwise dequantization."""
|
| 72 |
+
codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32)
|
| 73 |
+
output = torch.zeros(numel, dtype=torch.float32)
|
| 74 |
+
|
| 75 |
+
for i in range(numel):
|
| 76 |
+
byte_idx = i // 2
|
| 77 |
+
block_idx = i // blocksize
|
| 78 |
+
byte_val = packed[byte_idx].item()
|
| 79 |
+
|
| 80 |
+
if i % 2 == 0:
|
| 81 |
+
nibble = (byte_val >> 4) & 0x0F
|
| 82 |
+
else:
|
| 83 |
+
nibble = byte_val & 0x0F
|
| 84 |
+
|
| 85 |
+
output[i] = codebook[nibble] * absmax[block_idx].item()
|
| 86 |
+
|
| 87 |
+
return output
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ============================================================================
|
| 91 |
+
# Quantization / Dequantization Tests
|
| 92 |
+
# ============================================================================
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@pytest.mark.parametrize("blocksize", [64, 128])
|
| 96 |
+
@pytest.mark.parametrize("quant_type", [NF4, FP4])
|
| 97 |
+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
| 98 |
+
def test_quantize_dequantize_roundtrip(blocksize, quant_type, dtype):
|
| 99 |
+
"""Test that quantize -> dequantize approximately recovers the original."""
|
| 100 |
+
torch.manual_seed(42)
|
| 101 |
+
n = 1024
|
| 102 |
+
x = torch.randn(n, dtype=dtype, device=DEVICE)
|
| 103 |
+
|
| 104 |
+
packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=quant_type)
|
| 105 |
+
|
| 106 |
+
assert packed.shape == (n // 2,)
|
| 107 |
+
assert packed.dtype == torch.uint8
|
| 108 |
+
assert absmax.dtype == torch.float32
|
| 109 |
+
assert absmax.shape == ((n + blocksize - 1) // blocksize,)
|
| 110 |
+
|
| 111 |
+
x_deq = dequantize_4bit(
|
| 112 |
+
packed, absmax, blocksize=blocksize, quant_type=quant_type,
|
| 113 |
+
numel=n, output_dtype=dtype,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
assert x_deq.shape == (n,)
|
| 117 |
+
assert x_deq.dtype == dtype
|
| 118 |
+
|
| 119 |
+
# 4-bit quantization has significant error; check correlation
|
| 120 |
+
x_cpu = x.float().cpu()
|
| 121 |
+
x_deq_cpu = x_deq.float().cpu()
|
| 122 |
+
cosine_sim = torch.nn.functional.cosine_similarity(
|
| 123 |
+
x_cpu.unsqueeze(0), x_deq_cpu.unsqueeze(0)
|
| 124 |
+
).item()
|
| 125 |
+
assert cosine_sim > 0.95, f"Cosine similarity too low: {cosine_sim}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@pytest.mark.parametrize("blocksize", [64, 128])
|
| 129 |
+
def test_dequantize_matches_reference(blocksize):
|
| 130 |
+
"""Test dequantization matches the Python reference implementation."""
|
| 131 |
+
torch.manual_seed(123)
|
| 132 |
+
n = 256
|
| 133 |
+
x = torch.randn(n, dtype=torch.float16, device=DEVICE)
|
| 134 |
+
|
| 135 |
+
packed, absmax = quantize_4bit(x, blocksize=blocksize, quant_type=NF4)
|
| 136 |
+
|
| 137 |
+
# GPU dequantize
|
| 138 |
+
x_deq = dequantize_4bit(
|
| 139 |
+
packed, absmax, blocksize=blocksize, quant_type=NF4,
|
| 140 |
+
numel=n, output_dtype=torch.float16,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Reference dequantize (on CPU)
|
| 144 |
+
x_ref = _reference_dequantize_nf4(
|
| 145 |
+
packed.cpu(), absmax.cpu(), blocksize, n
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
torch.testing.assert_close(
|
| 149 |
+
x_deq.float().cpu(), x_ref, rtol=1e-3, atol=1e-3
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ============================================================================
|
| 154 |
+
# GEMV Tests
|
| 155 |
+
# ============================================================================
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@pytest.mark.parametrize("blocksize", [64, 128])
|
| 159 |
+
@pytest.mark.parametrize("quant_type", [NF4, FP4])
|
| 160 |
+
def test_gemv_correctness(blocksize, quant_type):
|
| 161 |
+
"""Test fused GEMV against dequantize + matmul reference."""
|
| 162 |
+
torch.manual_seed(42)
|
| 163 |
+
N, K = 256, 256
|
| 164 |
+
|
| 165 |
+
# Create weight and quantize
|
| 166 |
+
W = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
|
| 167 |
+
W_flat = W.flatten()
|
| 168 |
+
packed, absmax = quantize_4bit(W_flat, blocksize=blocksize, quant_type=quant_type)
|
| 169 |
+
|
| 170 |
+
# Reshape for GEMV
|
| 171 |
+
packed_w = packed.view(N, K // 2)
|
| 172 |
+
absmax_w = absmax.view(N, -1)
|
| 173 |
+
|
| 174 |
+
# Input vector
|
| 175 |
+
x = torch.randn(K, dtype=torch.float16, device=DEVICE)
|
| 176 |
+
|
| 177 |
+
# Fused GEMV
|
| 178 |
+
y = gemv_4bit(x, packed_w, absmax_w, output_features=N,
|
| 179 |
+
blocksize=blocksize, quant_type=quant_type)
|
| 180 |
+
|
| 181 |
+
# Reference: dequantize then matmul
|
| 182 |
+
W_deq = dequantize_4bit(packed, absmax, blocksize=blocksize,
|
| 183 |
+
quant_type=quant_type, numel=N*K,
|
| 184 |
+
output_dtype=torch.float16)
|
| 185 |
+
W_deq = W_deq.view(N, K)
|
| 186 |
+
y_ref = W_deq @ x
|
| 187 |
+
|
| 188 |
+
# Check relative error
|
| 189 |
+
rel_error = (y.float() - y_ref.float()).abs().mean() / y_ref.float().abs().mean()
|
| 190 |
+
assert rel_error < 0.05, f"GEMV relative error too high: {rel_error}"
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ============================================================================
|
| 194 |
+
# GEMM Tests
|
| 195 |
+
# ============================================================================
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@pytest.mark.parametrize("blocksize", [64, 128])
@pytest.mark.parametrize("quant_type", [NF4, FP4])
def test_gemm_correctness(blocksize, quant_type):
    """Fused GEMM output must agree with a dequantize-then-matmul reference."""
    torch.manual_seed(42)
    rows, n_out, k_in = 8, 128, 128

    # Quantize a random fp16 weight matrix blockwise.
    weight = torch.randn(n_out, k_in, dtype=torch.float16, device=DEVICE)
    packed, absmax = quantize_4bit(
        weight.flatten(), blocksize=blocksize, quant_type=quant_type
    )

    # Kernel-facing layouts: packed nibbles and per-block scales.
    packed_w = packed.view(n_out, k_in // 2)
    absmax_w = absmax.view(n_out, -1)

    # Batched input activations.
    X = torch.randn(rows, k_in, dtype=torch.float16, device=DEVICE)

    # Fused kernel under test.
    Y = gemm_4bit(
        X,
        packed_w,
        absmax_w,
        output_features=n_out,
        blocksize=blocksize,
        quant_type=quant_type,
    )

    # Reference path: dequantize, then X @ W^T.
    w_ref = dequantize_4bit(
        packed,
        absmax,
        blocksize=blocksize,
        quant_type=quant_type,
        numel=n_out * k_in,
        output_dtype=torch.float16,
    ).view(n_out, k_in)
    Y_ref = X @ w_ref.T

    # Mean relative error keeps the tolerance scale-independent.
    rel_error = (Y.float() - Y_ref.float()).abs().mean() / Y_ref.float().abs().mean()
    assert rel_error < 0.05, f"GEMM relative error too high: {rel_error}"
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# ============================================================================
|
| 230 |
+
# Linear layer test
|
| 231 |
+
# ============================================================================
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def test_linear_4bit_auto_select():
    """linear_4bit must handle both a single vector (GEMV path) and a batch (GEMM path)."""
    torch.manual_seed(42)
    n_out, k_in = 128, 128

    # Prepare one quantized weight shared by both dispatch paths.
    weight = torch.randn(n_out, k_in, dtype=torch.float16, device=DEVICE)
    packed, absmax = quantize_4bit(weight.flatten(), blocksize=64, quant_type=NF4)
    packed_w = packed.view(n_out, k_in // 2)
    absmax_w = absmax.view(n_out, -1)

    # A 1-D input should take the GEMV path and come back 1-D.
    vec = torch.randn(k_in, dtype=torch.float16, device=DEVICE)
    out_vec = linear_4bit(vec, packed_w, absmax_w, output_features=n_out)
    assert out_vec.shape == (n_out,)

    # A batched input should take the GEMM path and keep its batch dim.
    batch = torch.randn(4, k_in, dtype=torch.float16, device=DEVICE)
    out_batch = linear_4bit(batch, packed_w, absmax_w, output_features=n_out)
    assert out_batch.shape == (4, n_out)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
    # Propagate pytest's exit status: without SystemExit, running this file
    # directly would always exit 0 even when tests fail, hiding failures
    # from CI or shell scripts.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
torch-ext/bitsandbytes_mps/__init__.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
# Quant type constants (match bitsandbytes DataType_t)
FP4 = 1  # 4-bit float codebook
NF4 = 2  # 4-bit NormalFloat codebook
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def quantize_4bit(
    input: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor blockwise to 4 bits with the NF4 or FP4 codebook.

    Args:
        input: Tensor on the MPS device (float16, bfloat16, or float32).
        blocksize: Elements per quantization block (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        ``(packed, absmax)`` where ``packed`` is a uint8 tensor of packed
        4-bit values (two per byte, length numel/2) and ``absmax`` holds the
        per-block max absolute values as float32.
    """
    # Thin wrapper: the registered MPS custom op does all the work.
    return ops.bnb_quantize_4bit(input, blocksize, quant_type)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def dequantize_4bit(
    packed: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
    numel: int = -1,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Reverse blockwise 4-bit (NF4/FP4) quantization.

    Args:
        packed: uint8 tensor of packed 4-bit values.
        absmax: float32 tensor of per-block max absolute values.
        blocksize: Elements per quantization block (64 or 128).
        quant_type: FP4 (1) or NF4 (2).
        numel: Element count of the original tensor; pass -1 to infer it
            as ``packed.numel() * 2``.
        output_dtype: Scalar type of the dequantized result.

    Returns:
        The dequantized tensor.
    """
    # Two 4-bit values live in every packed byte, hence the factor of 2
    # when the caller did not supply an explicit element count.
    effective_numel = packed.numel() * 2 if numel < 0 else numel
    return ops.bnb_dequantize_4bit(
        packed, absmax, blocksize, quant_type, effective_numel, output_dtype
    )
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def gemv_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Matrix-vector product against blockwise 4-bit quantized weights.

    Evaluates ``y = dequant(W) @ x`` in a single fused kernel, without
    materializing the dequantized weight matrix.

    Args:
        x: Input vector [..., K] on the MPS device.
        w: Packed weight matrix [N, K/2] (uint8) on the MPS device.
        absmax: Per-block scales [N, ceil(K/blocksize)] (float32).
        output_features: Number of output features (N).
        blocksize: Quantization block size (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        Output tensor [..., N].
    """
    # Thin wrapper over the registered MPS custom op.
    return ops.bnb_gemv_4bit(x, w, absmax, blocksize, quant_type, output_features)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def gemm_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Matrix-matrix product against blockwise 4-bit quantized, transposed weights.

    Evaluates ``Y = X @ dequant(W).T`` in a single fused kernel, without
    materializing the dequantized weight matrix.

    Args:
        x: Input matrix [..., M, K] on the MPS device.
        w: Packed weight matrix [N, K/2] (uint8) on the MPS device.
        absmax: Per-block scales [N, ceil(K/blocksize)] (float32).
        output_features: Number of output features (N).
        blocksize: Quantization block size (64 or 128).
        quant_type: FP4 (1) or NF4 (2).

    Returns:
        Output tensor [..., M, N].
    """
    # Thin wrapper over the registered MPS custom op.
    return ops.bnb_gemm_4bit(x, w, absmax, blocksize, quant_type, output_features)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def linear_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """4-bit quantized linear layer that dispatches to GEMV or GEMM.

    A 1-D input, or an input whose second-to-last dimension is 1, is routed
    through the fused GEMV kernel; anything else goes through the fused GEMM.

    Args:
        x: Input tensor on the MPS device.
        w: Packed weight [N, K/2] (uint8).
        absmax: Scales [N, ceil(K/blocksize)] (float32).
        output_features: N.
        blocksize: 64 or 128.
        quant_type: FP4 (1) or NF4 (2).
        bias: Optional bias [N].

    Returns:
        Output tensor.
    """
    is_vector = x.dim() == 1
    takes_gemv = is_vector or (x.dim() >= 2 and x.size(-2) == 1)

    if takes_gemv:
        # Collapse to a plain vector for the GEMV kernel.
        flat = x.view(x.size(-1)) if is_vector else x.squeeze(-2)
        out = gemv_4bit(flat, w, absmax, output_features, blocksize, quant_type)
        # Restore the caller's dimensionality.
        if is_vector:
            out = out.squeeze(0)
        elif x.dim() >= 2:
            out = out.unsqueeze(-2)
    else:
        out = gemm_4bit(x, w, absmax, output_features, blocksize, quant_type)

    return out if bias is None else out + bias
|
| 158 |
+
|
| 159 |
+
# Public API surface of this package.
__all__ = [
    "quantize_4bit",
    "dequantize_4bit",
    "gemv_4bit",
    "gemm_4bit",
    "linear_4bit",
]
|
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/library.h>
|
| 2 |
+
|
| 3 |
+
#include "registration.h"
|
| 4 |
+
#include "torch_binding.h"
|
| 5 |
+
|
| 6 |
+
// Op schema definitions for the bitsandbytes MPS extension.  The schema
// strings must stay in sync with the Python wrappers in
// torch-ext/bitsandbytes_mps/__init__.py and the declarations in
// torch_binding.h.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    // 4-bit quantization
    ops.def(
        "bnb_quantize_4bit(Tensor input, int blocksize, int quant_type) "
        "-> (Tensor, Tensor)");

    // 4-bit dequantization
    ops.def(
        "bnb_dequantize_4bit(Tensor packed, Tensor absmax, int blocksize, "
        "int quant_type, int numel, ScalarType output_dtype) -> Tensor");

    // Fused GEMV with 4-bit weights
    ops.def(
        "bnb_gemv_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, "
        "int quant_type, int output_features) -> Tensor");

    // Fused GEMM with 4-bit transposed weights
    ops.def(
        "bnb_gemm_4bit(Tensor x, Tensor w, Tensor absmax, int blocksize, "
        "int quant_type, int output_features) -> Tensor");
}
|
| 27 |
+
|
| 28 |
+
// Bind each schema above to its MPS backend implementation.  The bound
// functions are declared in torch_binding.h; their definitions live
// elsewhere in the project (presumably the Objective-C++ kernels —
// confirm against the build).
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, MPS, ops) {
    ops.impl("bnb_quantize_4bit", bnb_quantize_4bit);
    ops.impl("bnb_dequantize_4bit", bnb_dequantize_4bit);
    ops.impl("bnb_gemv_4bit", bnb_gemv_4bit);
    ops.impl("bnb_gemm_4bit", bnb_gemm_4bit);
}

// Expose the registered ops as a loadable Python extension module.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <ATen/ATen.h>
#include <tuple>

// ============================================================================
// Blockwise 4-bit quantization (NF4/FP4)
// ============================================================================

// Quantize and return both packed tensor and absmax
// input: source tensor; blocksize: elements per quantization block;
// quant_type: FP4 (1) or NF4 (2), per the Python wrapper constants.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type);

// ============================================================================
// Blockwise 4-bit dequantization
// ============================================================================

// Dequantize packed 4-bit tensor back to output_dtype
// numel: element count of the original (unpacked) tensor.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    c10::ScalarType output_dtype);

// ============================================================================
// Fused GEMV: y = dequant(W) @ x
// W: [N, K/2] packed, absmax: [N, K_groups], x: [..., K], y: [..., N]
// ============================================================================

at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);

// ============================================================================
// Fused GEMM: Y = X @ dequant(W).T
// X: [M, K], W: [N, K/2] packed, absmax: [N, K_groups], Y: [M, N]
// ============================================================================

at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);
|