robtaylor-chipflow commited on
Commit
949658a
·
0 Parent(s):

Add Metal rotary embedding kernel matching vLLM interface

Browse files

New Metal kernel implementing rotary_embedding(positions, query, key,
head_size, cos_sin_cache, is_neox) with the exact signature vLLM expects.

Features:
- Supports fp16, bf16, fp32 dtypes
- NeoX style (Llama, Mistral) and GPT-J style rotation
- Arbitrary head dims (64, 128, 256)
- GQA support (separate num_heads / num_kv_heads)
- Optional key tensor (key=None skips key rotation)
- Function constants for IS_NEOX (zero-cost specialization)
- Precomputed cos_sin_cache lookup (not on-the-fly frequency computation)

Includes comprehensive tests against pure-PyTorch reference implementation.

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)

build.toml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "rotary_embedding"
3
+ backends = ["metal"]
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h",
9
+ ]
10
+
11
+ [kernel.rotary_embedding_metal]
12
+ backend = "metal"
13
+ src = [
14
+ "rotary-embedding-metal/rotary_embedding.metal",
15
+ "rotary-embedding-metal/rotary_embedding.mm",
16
+ "rotary-embedding-metal/utils.metal",
17
+ ]
18
+ depends = ["torch"]
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for rotary embedding kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ inherit self;
15
+ path = ./.;
16
+ };
17
+ }
rotary-embedding-metal/rotary_embedding.metal ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ #include "utils.metal"
3
+
4
+ using namespace metal;
5
+
6
+ // Function constants for compile-time specialization.
7
+ // IS_NEOX: true for GPT-NeoX style (Llama, Mistral), false for GPT-J style.
8
+ constant bool IS_NEOX [[function_constant(0)]];
9
+
10
+ // Rotary embedding kernel.
11
+ //
12
+ // Each threadgroup processes one token. Threads within the threadgroup
13
+ // are mapped to (head_idx, rot_offset) pairs covering both query and key.
14
+ //
15
+ // The cos_sin_cache layout is [max_position, rot_dim] where:
16
+ // cache[pos, 0:rot_dim/2] = cos values
17
+ // cache[pos, rot_dim/2:rot_dim] = sin values
18
+ //
19
+ // For NeoX style (IS_NEOX=true):
20
+ // x_index = rot_offset, y_index = embed_dim + rot_offset
21
+ // For GPT-J style (IS_NEOX=false):
22
+ // x_index = 2 * rot_offset, y_index = 2 * rot_offset + 1
23
+ template <typename scalar_t>
24
+ kernel void rotary_embedding_kernel(
25
+ const device int64_t *positions [[buffer(0)]],
26
+ device scalar_t *query [[buffer(1)]],
27
+ device scalar_t *key [[buffer(2)]],
28
+ const device scalar_t *cos_sin_cache [[buffer(3)]],
29
+ const device int &rot_dim [[buffer(4)]],
30
+ const device int64_t &query_stride [[buffer(5)]],
31
+ const device int64_t &key_stride [[buffer(6)]],
32
+ const device int &head_size [[buffer(7)]],
33
+ const device int &num_heads [[buffer(8)]],
34
+ const device int &num_kv_heads [[buffer(9)]],
35
+ const device int &has_key [[buffer(10)]],
36
+ uint token_idx [[threadgroup_position_in_grid]],
37
+ uint tid [[thread_position_in_threadgroup]],
38
+ uint threads_per_tg [[threads_per_threadgroup]]) {
39
+
40
+ const int embed_dim = rot_dim / 2;
41
+ const int64_t pos = positions[token_idx];
42
+ const device scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
43
+
44
+ // Process query heads.
45
+ for (int i = tid; i < num_heads * embed_dim; i += threads_per_tg) {
46
+ const int head_idx = i / embed_dim;
47
+ const int rot_offset = i % embed_dim;
48
+
49
+ int x_index, y_index;
50
+ if (IS_NEOX) {
51
+ x_index = rot_offset;
52
+ y_index = embed_dim + rot_offset;
53
+ } else {
54
+ x_index = 2 * rot_offset;
55
+ y_index = 2 * rot_offset + 1;
56
+ }
57
+
58
+ const int64_t token_head = token_idx * query_stride + head_idx * head_size;
59
+
60
+ const float cos_val = static_cast<float>(cache_ptr[rot_offset]);
61
+ const float sin_val = static_cast<float>(cache_ptr[embed_dim + rot_offset]);
62
+
63
+ const float x = static_cast<float>(query[token_head + x_index]);
64
+ const float y = static_cast<float>(query[token_head + y_index]);
65
+ query[token_head + x_index] = static_cast<scalar_t>(x * cos_val - y * sin_val);
66
+ query[token_head + y_index] = static_cast<scalar_t>(y * cos_val + x * sin_val);
67
+ }
68
+
69
+ // Process key heads (if key is provided).
70
+ if (has_key) {
71
+ for (int i = tid; i < num_kv_heads * embed_dim; i += threads_per_tg) {
72
+ const int head_idx = i / embed_dim;
73
+ const int rot_offset = i % embed_dim;
74
+
75
+ int x_index, y_index;
76
+ if (IS_NEOX) {
77
+ x_index = rot_offset;
78
+ y_index = embed_dim + rot_offset;
79
+ } else {
80
+ x_index = 2 * rot_offset;
81
+ y_index = 2 * rot_offset + 1;
82
+ }
83
+
84
+ const int64_t token_head = token_idx * key_stride + head_idx * head_size;
85
+
86
+ const float cos_val = static_cast<float>(cache_ptr[rot_offset]);
87
+ const float sin_val = static_cast<float>(cache_ptr[embed_dim + rot_offset]);
88
+
89
+ const float x = static_cast<float>(key[token_head + x_index]);
90
+ const float y = static_cast<float>(key[token_head + y_index]);
91
+ key[token_head + x_index] = static_cast<scalar_t>(x * cos_val - y * sin_val);
92
+ key[token_head + y_index] = static_cast<scalar_t>(y * cos_val + x * sin_val);
93
+ }
94
+ }
95
+ }
96
+
97
+ // Instantiate kernel variants for each dtype.
98
+ #define instantiate_rotary_embedding(type) \
99
+ template [[host_name("rotary_embedding_" #type)]] [[kernel]] void \
100
+ rotary_embedding_kernel<type>( \
101
+ const device int64_t *positions [[buffer(0)]], \
102
+ device type *query [[buffer(1)]], \
103
+ device type *key [[buffer(2)]], \
104
+ const device type *cos_sin_cache [[buffer(3)]], \
105
+ const device int &rot_dim [[buffer(4)]], \
106
+ const device int64_t &query_stride [[buffer(5)]], \
107
+ const device int64_t &key_stride [[buffer(6)]], \
108
+ const device int &head_size [[buffer(7)]], \
109
+ const device int &num_heads [[buffer(8)]], \
110
+ const device int &num_kv_heads [[buffer(9)]], \
111
+ const device int &has_key [[buffer(10)]], \
112
+ uint token_idx [[threadgroup_position_in_grid]], \
113
+ uint tid [[thread_position_in_threadgroup]], \
114
+ uint threads_per_tg [[threads_per_threadgroup]]);
115
+
116
+ instantiate_rotary_embedding(float);
117
+ instantiate_rotary_embedding(half);
118
+ instantiate_rotary_embedding(bfloat16_t);
rotary-embedding-metal/rotary_embedding.mm ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/mps/MPSDevice.h>
2
+ #include <ATen/mps/MPSStream.h>
3
+ #include <torch/torch.h>
4
+
5
+ #import <Foundation/Foundation.h>
6
+ #import <Metal/Metal.h>
7
+ #include <dlfcn.h>
8
+ #include <string>
9
+
10
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
11
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
12
+ }
13
+
14
+ static std::string getModuleDirectory() {
15
+ Dl_info dl_info;
16
+ if (dladdr((void *)getModuleDirectory, &dl_info)) {
17
+ std::string path(dl_info.dli_fname);
18
+ size_t pos = path.find_last_of('/');
19
+ if (pos != std::string::npos) {
20
+ return path.substr(0, pos);
21
+ }
22
+ }
23
+ return ".";
24
+ }
25
+
26
+ void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
27
+ std::optional<torch::Tensor> key, int64_t head_size,
28
+ torch::Tensor &cos_sin_cache, bool is_neox) {
29
+ TORCH_CHECK(query.device().is_mps(), "query must be on MPS device");
30
+ TORCH_CHECK(positions.device().is_mps(), "positions must be on MPS device");
31
+ TORCH_CHECK(cos_sin_cache.device().is_mps(),
32
+ "cos_sin_cache must be on MPS device");
33
+
34
+ // Determine tensor dimensions.
35
+ // positions: [num_tokens] or [batch, seq_len]
36
+ // query: [num_tokens, num_heads * head_size] or
37
+ // [num_tokens, num_heads, head_size]
38
+ const int64_t num_tokens = positions.numel();
39
+
40
+ // Flatten positions to 1D for kernel simplicity.
41
+ torch::Tensor positions_flat = positions.reshape({-1});
42
+
43
+ // Compute query/key strides along the token dimension.
44
+ // Standard layout: [num_tokens, num_heads, head_size]
45
+ // Batched layout: [batch, seq_len, num_heads, head_size]
46
+ // The token dim is at index (positions.dim() - 1) in query after
47
+ // accounting for batch dims, but we flatten positions so use stride(0)
48
+ // relative to the flattened view.
49
+ //
50
+ // For standard [num_tokens, num_heads, head_size]:
51
+ // query_stride = num_heads * head_size (stride along dim 0)
52
+ // For flattened [num_tokens, num_heads * head_size]:
53
+ // query_stride = num_heads * head_size (stride along dim 0)
54
+ int64_t query_stride = query.stride(0);
55
+ int64_t key_stride = key.has_value() ? key->stride(0) : 0;
56
+
57
+ // Compute num_heads from tensor size. Works for both flat and split layouts.
58
+ const int num_heads =
59
+ static_cast<int>(query.numel() / (num_tokens * head_size));
60
+ const int num_kv_heads =
61
+ key.has_value()
62
+ ? static_cast<int>(key->numel() / (num_tokens * head_size))
63
+ : 0;
64
+
65
+ const int rot_dim = cos_sin_cache.size(-1);
66
+ const int embed_dim = rot_dim / 2;
67
+ const int has_key = key.has_value() ? 1 : 0;
68
+
69
+ @autoreleasepool {
70
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
71
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
72
+
73
+ id<MTLDevice> device = stream->device();
74
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
75
+ TORCH_CHECK(cmdBuf, "Failed to get command buffer");
76
+
77
+ // Load metallib.
78
+ std::string moduleDir = getModuleDirectory();
79
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
80
+
81
+ NSString *metallibPathStr =
82
+ [NSString stringWithUTF8String:metallibPath.c_str()];
83
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
84
+ NSError *error = nil;
85
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
86
+ TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath,
87
+ error ? [NSString stringWithFormat:@": %@",
88
+ error.localizedDescription]
89
+ .UTF8String
90
+ : "");
91
+
92
+ // Select kernel variant based on dtype.
93
+ NSString *kernName = nil;
94
+ switch (query.scalar_type()) {
95
+ case torch::kFloat:
96
+ kernName = @"rotary_embedding_float";
97
+ break;
98
+ case torch::kHalf:
99
+ kernName = @"rotary_embedding_half";
100
+ break;
101
+ case torch::kBFloat16:
102
+ kernName = @"rotary_embedding_bfloat16_t";
103
+ break;
104
+ default:
105
+ TORCH_CHECK(false, "Unsupported dtype for rotary_embedding: ",
106
+ query.scalar_type());
107
+ }
108
+
109
+ // Set function constant for IS_NEOX.
110
+ MTLFunctionConstantValues *constants =
111
+ [[MTLFunctionConstantValues alloc] init];
112
+ [constants setConstantValue:&is_neox type:MTLDataTypeBool atIndex:0];
113
+
114
+ id<MTLFunction> fn = [lib newFunctionWithName:kernName
115
+ constantValues:constants
116
+ error:&error];
117
+ TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String,
118
+ error ? [NSString stringWithFormat:@": %@",
119
+ error.localizedDescription]
120
+ .UTF8String
121
+ : "");
122
+
123
+ id<MTLComputePipelineState> pso =
124
+ [device newComputePipelineStateWithFunction:fn error:&error];
125
+ TORCH_CHECK(pso, "Failed to create pipeline state",
126
+ error ? [NSString stringWithFormat:@": %@",
127
+ error.localizedDescription]
128
+ .UTF8String
129
+ : "");
130
+
131
+ dispatch_queue_t q = stream->queue();
132
+ dispatch_sync(q, ^{
133
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
134
+ TORCH_CHECK(enc, "Failed to create compute encoder");
135
+
136
+ [enc setComputePipelineState:pso];
137
+
138
+ // Buffer 0: positions (flattened)
139
+ [enc setBuffer:getMTLBufferStorage(positions_flat)
140
+ offset:positions_flat.storage_offset() *
141
+ positions_flat.element_size()
142
+ atIndex:0];
143
+
144
+ // Buffer 1: query
145
+ [enc setBuffer:getMTLBufferStorage(query)
146
+ offset:query.storage_offset() * query.element_size()
147
+ atIndex:1];
148
+
149
+ // Buffer 2: key (or query as dummy if no key)
150
+ if (key.has_value()) {
151
+ [enc setBuffer:getMTLBufferStorage(*key)
152
+ offset:key->storage_offset() * key->element_size()
153
+ atIndex:2];
154
+ } else {
155
+ // Pass query buffer as dummy; has_key=0 ensures it's never accessed
156
+ [enc setBuffer:getMTLBufferStorage(query)
157
+ offset:query.storage_offset() * query.element_size()
158
+ atIndex:2];
159
+ }
160
+
161
+ // Buffer 3: cos_sin_cache
162
+ [enc setBuffer:getMTLBufferStorage(cos_sin_cache)
163
+ offset:cos_sin_cache.storage_offset() *
164
+ cos_sin_cache.element_size()
165
+ atIndex:3];
166
+
167
+ // Scalar parameters via setBytes.
168
+ const int32_t rot_dim_i32 = static_cast<int32_t>(rot_dim);
169
+ [enc setBytes:&rot_dim_i32 length:sizeof(int32_t) atIndex:4];
170
+
171
+ [enc setBytes:&query_stride length:sizeof(int64_t) atIndex:5];
172
+ [enc setBytes:&key_stride length:sizeof(int64_t) atIndex:6];
173
+
174
+ const int32_t head_size_i32 = static_cast<int32_t>(head_size);
175
+ [enc setBytes:&head_size_i32 length:sizeof(int32_t) atIndex:7];
176
+
177
+ const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
178
+ [enc setBytes:&num_heads_i32 length:sizeof(int32_t) atIndex:8];
179
+
180
+ const int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
181
+ [enc setBytes:&num_kv_heads_i32 length:sizeof(int32_t) atIndex:9];
182
+
183
+ const int32_t has_key_i32 = static_cast<int32_t>(has_key);
184
+ [enc setBytes:&has_key_i32 length:sizeof(int32_t) atIndex:10];
185
+
186
+ // Dispatch: one threadgroup per token.
187
+ const uint32_t threads_per_tg =
188
+ std::min<uint32_t>(512, std::max(num_heads, num_kv_heads) * embed_dim);
189
+ MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
190
+ MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
191
+
192
+ [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
193
+ [enc endEncoding];
194
+ });
195
+
196
+ stream->synchronize(at::mps::SyncType::COMMIT);
197
+ }
198
+ }
rotary-embedding-metal/utils.metal ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ #if defined(__HAVE_BFLOAT__)
5
+
6
+ typedef bfloat bfloat16_t;
7
+
8
+ #else
9
+
10
+ constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
11
+ if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
12
+ _fp_encoding_traits<float>::inf_mask) {
13
+ return uint16_t(as_type<uint32_t>(0x7FC0));
14
+ }
15
+ uint32_t float_bits = as_type<uint32_t>(x);
16
+ float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
17
+ return float_bits >> 16;
18
+ }
19
+
20
+ constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
21
+ return as_type<float>((uint32_t)x << 16);
22
+ }
23
+
24
+ struct _MLX_BFloat16;
25
+
26
+ template <typename T>
27
+ static constexpr constant bool can_convert_to_bfloat =
28
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
29
+
30
+ template <typename T>
31
+ static constexpr constant bool can_convert_from_bfloat =
32
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
33
+
34
+ struct _MLX_BFloat16 {
35
+ uint16_t bits_;
36
+ _MLX_BFloat16() thread = default;
37
+ _MLX_BFloat16() threadgroup = default;
38
+ _MLX_BFloat16() device = default;
39
+ _MLX_BFloat16() constant = default;
40
+
41
+ struct bits_to_bfloat_struct {};
42
+ static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
43
+ return bits_to_bfloat_struct();
44
+ }
45
+ constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
46
+ : bits_(bits) {}
47
+
48
+ template <typename T,
49
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
50
+ constexpr METAL_FUNC _MLX_BFloat16(T x) thread
51
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
52
+
53
+ template <typename T,
54
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
55
+ constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
56
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
57
+
58
+ template <typename T,
59
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
60
+ constexpr METAL_FUNC _MLX_BFloat16(T x) device
61
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
62
+
63
+ template <typename T,
64
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
65
+ constexpr METAL_FUNC _MLX_BFloat16(T x) constant
66
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
67
+
68
+ template <typename T,
69
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
70
+ constexpr METAL_FUNC operator T() const thread {
71
+ return static_cast<T>(bfloat_bits_to_float(bits_));
72
+ }
73
+
74
+ template <typename T,
75
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
76
+ constexpr METAL_FUNC operator T() const threadgroup {
77
+ return static_cast<T>(bfloat_bits_to_float(bits_));
78
+ }
79
+
80
+ template <typename T,
81
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
82
+ constexpr METAL_FUNC operator T() const device {
83
+ return static_cast<T>(bfloat_bits_to_float(bits_));
84
+ }
85
+
86
+ template <typename T,
87
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
88
+ constexpr METAL_FUNC operator T() constant {
89
+ return static_cast<T>(bfloat_bits_to_float(bits_));
90
+ }
91
+ };
92
+
93
+ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
94
+ return -static_cast<float>(x);
95
+ }
96
+
97
+ #define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
98
+ constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
99
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
100
+ }
101
+
102
+ #define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
103
+ constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
104
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
105
+ } \
106
+ constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
107
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
108
+ }
109
+
110
+ #define bfloat_binop(_op_, _operator_) \
111
+ bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
112
+ _MLX_BFloat16, float); \
113
+ bfloat_binop_helper(_op_, _operator_, float, float, float); \
114
+ bfloat_binop_helper(_op_, _operator_, float, half, float); \
115
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
116
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
117
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
118
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
119
+
120
+ bfloat_binop(+, operator+);
121
+ bfloat_binop(-, operator-);
122
+ bfloat_binop(*, operator*);
123
+ bfloat_binop(/, operator/);
124
+
125
+ #undef bfloat_binop_base
126
+ #undef bfloat_binop_helper
127
+ #undef bfloat_binop
128
+
129
+ typedef struct _MLX_BFloat16 bfloat16_t;
130
+
131
+ #endif
tests/test_rotary_embedding.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for Metal rotary embedding kernel.
2
+
3
+ Validates correctness against a pure-PyTorch reference implementation
4
+ for both NeoX (Llama/Mistral) and GPT-J rotation styles.
5
+ """
6
+
7
+ import pytest
8
+ import torch
9
+
10
+ import rotary_embedding as ops
11
+
12
+
13
+ def _is_mps_available() -> bool:
14
+ return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
15
+
16
+
17
+ if _is_mps_available():
18
+ DEVICES = ["mps"]
19
+ else:
20
+ DEVICES = [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
21
+
22
+ DTYPES = [torch.float32, torch.float16, torch.bfloat16]
23
+ HEAD_SIZES = [64, 128, 256]
24
+ NUM_HEADS = [8, 32]
25
+ NUM_KV_HEADS = [1, 8] # GQA and MHA
26
+ IS_NEOX = [True, False]
27
+ NUM_TOKENS = [1, 7, 32]
28
+ MAX_POSITION = 8192
29
+ ROTARY_DIM_FRACTIONS = [1.0] # Full rotation; 0.5 for partial
30
+
31
+
32
+ def _build_cos_sin_cache(
33
+ max_position: int,
34
+ rotary_dim: int,
35
+ dtype: torch.dtype,
36
+ device: str,
37
+ base: float = 10000.0,
38
+ ) -> torch.Tensor:
39
+ """Build a cos/sin cache matching vLLM's convention."""
40
+ inv_freq = 1.0 / (
41
+ base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
42
+ )
43
+ t = torch.arange(max_position, dtype=torch.float32)
44
+ freqs = torch.outer(t, inv_freq) # [max_position, rotary_dim/2]
45
+ cos_vals = freqs.cos()
46
+ sin_vals = freqs.sin()
47
+ cache = torch.cat([cos_vals, sin_vals], dim=-1) # [max_position, rotary_dim]
48
+ return cache.to(dtype=dtype, device=device)
49
+
50
+
51
+ def _ref_rotary_embedding(
52
+ positions: torch.Tensor,
53
+ query: torch.Tensor,
54
+ key: torch.Tensor | None,
55
+ head_size: int,
56
+ cos_sin_cache: torch.Tensor,
57
+ is_neox: bool,
58
+ ) -> None:
59
+ """Pure-PyTorch reference implementation of rotary embedding."""
60
+ rot_dim = cos_sin_cache.shape[-1]
61
+ embed_dim = rot_dim // 2
62
+
63
+ num_tokens = positions.numel()
64
+ positions_flat = positions.reshape(-1)
65
+
66
+ for t in range(num_tokens):
67
+ pos = positions_flat[t].item()
68
+ cos_vals = cos_sin_cache[pos, :embed_dim].float()
69
+ sin_vals = cos_sin_cache[pos, embed_dim:].float()
70
+
71
+ # Apply to query heads.
72
+ num_heads = query.shape[-2]
73
+ for h in range(num_heads):
74
+ for d in range(embed_dim):
75
+ if is_neox:
76
+ x_idx, y_idx = d, embed_dim + d
77
+ else:
78
+ x_idx, y_idx = 2 * d, 2 * d + 1
79
+
80
+ x = query[t, h, x_idx].float()
81
+ y = query[t, h, y_idx].float()
82
+ query[t, h, x_idx] = (x * cos_vals[d] - y * sin_vals[d]).to(
83
+ query.dtype
84
+ )
85
+ query[t, h, y_idx] = (y * cos_vals[d] + x * sin_vals[d]).to(
86
+ query.dtype
87
+ )
88
+
89
+ # Apply to key heads.
90
+ if key is not None:
91
+ num_kv_heads = key.shape[-2]
92
+ for h in range(num_kv_heads):
93
+ for d in range(embed_dim):
94
+ if is_neox:
95
+ x_idx, y_idx = d, embed_dim + d
96
+ else:
97
+ x_idx, y_idx = 2 * d, 2 * d + 1
98
+
99
+ x = key[t, h, x_idx].float()
100
+ y = key[t, h, y_idx].float()
101
+ key[t, h, x_idx] = (x * cos_vals[d] - y * sin_vals[d]).to(
102
+ key.dtype
103
+ )
104
+ key[t, h, y_idx] = (y * cos_vals[d] + x * sin_vals[d]).to(
105
+ key.dtype
106
+ )
107
+
108
+
109
+ @pytest.mark.parametrize("device", DEVICES)
110
+ @pytest.mark.parametrize("dtype", DTYPES)
111
+ @pytest.mark.parametrize("head_size", HEAD_SIZES)
112
+ @pytest.mark.parametrize("num_heads", NUM_HEADS)
113
+ @pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
114
+ @pytest.mark.parametrize("is_neox", IS_NEOX)
115
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
116
+ @torch.inference_mode()
117
+ def test_rotary_embedding(
118
+ device: str,
119
+ dtype: torch.dtype,
120
+ head_size: int,
121
+ num_heads: int,
122
+ num_kv_heads: int,
123
+ is_neox: bool,
124
+ num_tokens: int,
125
+ ) -> None:
126
+ # Skip invalid GQA configs.
127
+ if num_heads % num_kv_heads != 0:
128
+ pytest.skip("num_heads must be divisible by num_kv_heads")
129
+
130
+ rotary_dim = head_size # Full rotation
131
+ cos_sin_cache = _build_cos_sin_cache(
132
+ MAX_POSITION, rotary_dim, dtype, device
133
+ )
134
+
135
+ # Random positions (arbitrary, non-contiguous to test flexibility).
136
+ positions = torch.randint(0, MAX_POSITION, (num_tokens,), device=device)
137
+
138
+ # Random query and key tensors.
139
+ query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
140
+ key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=dtype, device=device)
141
+
142
+ # Clone for reference.
143
+ query_ref = query.clone()
144
+ key_ref = key.clone()
145
+
146
+ # Run kernel.
147
+ ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
148
+
149
+ # Run reference.
150
+ _ref_rotary_embedding(
151
+ positions.cpu() if device != "cpu" else positions,
152
+ query_ref.cpu() if device != "cpu" else query_ref,
153
+ key_ref.cpu() if device != "cpu" else key_ref,
154
+ head_size,
155
+ cos_sin_cache.cpu() if device != "cpu" else cos_sin_cache,
156
+ is_neox,
157
+ )
158
+
159
+ # Compare. Use relaxed tolerances for fp16/bf16.
160
+ if dtype == torch.float32:
161
+ atol, rtol = 1e-5, 1e-5
162
+ elif dtype == torch.float16:
163
+ atol, rtol = 1e-3, 1e-3
164
+ else: # bfloat16
165
+ atol, rtol = 2e-2, 2e-2
166
+
167
+ query_ref_dev = query_ref.to(device=device)
168
+ key_ref_dev = key_ref.to(device=device)
169
+
170
+ torch.testing.assert_close(query, query_ref_dev, atol=atol, rtol=rtol)
171
+ torch.testing.assert_close(key, key_ref_dev, atol=atol, rtol=rtol)
172
+
173
+
174
+ @pytest.mark.parametrize("device", DEVICES)
175
+ @pytest.mark.parametrize("dtype", [torch.float32])
176
+ @pytest.mark.parametrize("is_neox", [True])
177
+ @torch.inference_mode()
178
+ def test_rotary_embedding_no_key(
179
+ device: str,
180
+ dtype: torch.dtype,
181
+ is_neox: bool,
182
+ ) -> None:
183
+ """Test that passing key=None works correctly."""
184
+ head_size = 128
185
+ num_heads = 8
186
+ num_tokens = 4
187
+ rotary_dim = head_size
188
+ cos_sin_cache = _build_cos_sin_cache(
189
+ MAX_POSITION, rotary_dim, dtype, device
190
+ )
191
+ positions = torch.randint(0, MAX_POSITION, (num_tokens,), device=device)
192
+ query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
193
+
194
+ query_ref = query.clone()
195
+
196
+ # Run kernel with key=None.
197
+ ops.rotary_embedding(positions, query, None, head_size, cos_sin_cache, is_neox)
198
+
199
+ # Run reference with key=None.
200
+ _ref_rotary_embedding(
201
+ positions.cpu(),
202
+ query_ref.cpu(),
203
+ None,
204
+ head_size,
205
+ cos_sin_cache.cpu(),
206
+ is_neox,
207
+ )
208
+
209
+ query_ref_dev = query_ref.to(device=device)
210
+ torch.testing.assert_close(query, query_ref_dev, atol=1e-5, rtol=1e-5)
torch-ext/rotary_embedding/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from ._custom_ops import rotary_embedding
2
+ from ._ops import ops
3
+
4
+ __all__ = [
5
+ "ops",
6
+ "rotary_embedding",
7
+ ]
torch-ext/rotary_embedding/_custom_ops.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+
8
+ def rotary_embedding(
9
+ positions: torch.Tensor,
10
+ query: torch.Tensor,
11
+ key: Optional[torch.Tensor],
12
+ head_size: int,
13
+ cos_sin_cache: torch.Tensor,
14
+ is_neox: bool,
15
+ ) -> None:
16
+ ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
17
+ is_neox)
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
7
+ ops.def("rotary_embedding(Tensor positions, Tensor! query,"
8
+ " Tensor!? key, int head_size,"
9
+ " Tensor cos_sin_cache, bool is_neox) -> ()");
10
+ #if defined(METAL_KERNEL)
11
+ ops.impl("rotary_embedding", torch::kMPS, rotary_embedding);
12
+ #endif
13
+ }
14
+
15
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
6
+ std::optional<torch::Tensor> key, int64_t head_size,
7
+ torch::Tensor &cos_sin_cache, bool is_neox);