Commit ·
949658a
0
Parent(s):
Add Metal rotary embedding kernel matching vLLM interface
Browse filesNew Metal kernel implementing rotary_embedding(positions, query, key,
head_size, cos_sin_cache, is_neox) with the exact signature vLLM expects.
Features:
- Supports fp16, bf16, fp32 dtypes
- NeoX style (Llama, Mistral) and GPT-J style rotation
- Arbitrary head dims (64, 128, 256)
- GQA support (separate num_heads / num_kv_heads)
- Optional key tensor (key=None skips key rotation)
- Function constants for IS_NEOX (zero-cost specialization)
- Precomputed cos_sin_cache lookup (not on-the-fly frequency computation)
Includes comprehensive tests against pure-PyTorch reference implementation.
Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)
- build.toml +18 -0
- flake.nix +17 -0
- rotary-embedding-metal/rotary_embedding.metal +118 -0
- rotary-embedding-metal/rotary_embedding.mm +198 -0
- rotary-embedding-metal/utils.metal +131 -0
- tests/test_rotary_embedding.py +210 -0
- torch-ext/rotary_embedding/__init__.py +7 -0
- torch-ext/rotary_embedding/_custom_ops.py +17 -0
- torch-ext/torch_binding.cpp +15 -0
- torch-ext/torch_binding.h +7 -0
build.toml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Kernel-builder manifest for the Metal rotary embedding extension.
[general]
name = "rotary_embedding"
# Metal is the only backend this kernel targets (Apple GPUs via MPS).
backends = ["metal"]

# Torch C++ binding sources registering the op with PyTorch.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# The Metal kernel itself: shader sources plus the Objective-C++ host code.
[kernel.rotary_embedding_metal]
backend = "metal"
src = [
    "rotary-embedding-metal/rotary_embedding.metal",
    "rotary-embedding-metal/rotary_embedding.mm",
    "rotary-embedding-metal/utils.metal",
]
# Host code includes ATen/torch headers, so the torch dependency is required.
depends = ["torch"]
|
flake.nix
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
  description = "Flake for rotary embedding kernel";

  inputs = {
    # Hugging Face kernel-builder provides the build scaffolding
    # (toolchains, torch extension packaging) for build.toml projects.
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    # Delegate all flake outputs (packages, devShells, ...) to
    # kernel-builder, pointing it at this repository root.
    kernel-builder.lib.genFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
|
rotary-embedding-metal/rotary_embedding.metal
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <metal_stdlib>
#include "utils.metal"

using namespace metal;

// Function constants for compile-time specialization.
// IS_NEOX: true for GPT-NeoX style (Llama, Mistral), false for GPT-J style.
constant bool IS_NEOX [[function_constant(0)]];

// Rotates one (x, y) pair of `data` in place for a single
// (head, rot_offset) work item.
//
// The cos_sin_cache row layout is [rot_dim] where:
//   cache_ptr[0:rot_dim/2]       = cos values
//   cache_ptr[rot_dim/2:rot_dim] = sin values
//
// For NeoX style (IS_NEOX=true):
//   x_index = rot_offset, y_index = embed_dim + rot_offset
// For GPT-J style (IS_NEOX=false):
//   x_index = 2 * rot_offset, y_index = 2 * rot_offset + 1
//
// Math is done in float regardless of the storage type to limit
// rounding error for half/bfloat inputs.
template <typename scalar_t>
METAL_FUNC void apply_rotary_pair(device scalar_t *data,
                                  const device scalar_t *cache_ptr,
                                  int64_t token_head,
                                  int rot_offset,
                                  int embed_dim) {
  int x_index, y_index;
  if (IS_NEOX) {
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
  } else {
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
  }

  const float cos_val = static_cast<float>(cache_ptr[rot_offset]);
  const float sin_val = static_cast<float>(cache_ptr[embed_dim + rot_offset]);

  const float x = static_cast<float>(data[token_head + x_index]);
  const float y = static_cast<float>(data[token_head + y_index]);
  data[token_head + x_index] = static_cast<scalar_t>(x * cos_val - y * sin_val);
  data[token_head + y_index] = static_cast<scalar_t>(y * cos_val + x * sin_val);
}

// Rotary embedding kernel.
//
// Each threadgroup processes one token. Threads within the threadgroup
// are mapped to (head_idx, rot_offset) pairs covering both query and key.
// The same rotation helper is used for query and key; only the stride and
// head count differ.
template <typename scalar_t>
kernel void rotary_embedding_kernel(
    const device int64_t *positions [[buffer(0)]],
    device scalar_t *query [[buffer(1)]],
    device scalar_t *key [[buffer(2)]],
    const device scalar_t *cos_sin_cache [[buffer(3)]],
    const device int &rot_dim [[buffer(4)]],
    const device int64_t &query_stride [[buffer(5)]],
    const device int64_t &key_stride [[buffer(6)]],
    const device int &head_size [[buffer(7)]],
    const device int &num_heads [[buffer(8)]],
    const device int &num_kv_heads [[buffer(9)]],
    const device int &has_key [[buffer(10)]],
    uint token_idx [[threadgroup_position_in_grid]],
    uint tid [[thread_position_in_threadgroup]],
    uint threads_per_tg [[threads_per_threadgroup]]) {

  const int embed_dim = rot_dim / 2;
  const int64_t pos = positions[token_idx];
  const device scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;

  // Process query heads.
  for (int i = tid; i < num_heads * embed_dim; i += threads_per_tg) {
    const int head_idx = i / embed_dim;
    const int rot_offset = i % embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    apply_rotary_pair(query, cache_ptr, token_head, rot_offset, embed_dim);
  }

  // Process key heads (if key is provided; has_key=0 means buffer(2) is a
  // dummy binding that must not be touched).
  if (has_key) {
    for (int i = tid; i < num_kv_heads * embed_dim; i += threads_per_tg) {
      const int head_idx = i / embed_dim;
      const int rot_offset = i % embed_dim;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
      apply_rotary_pair(key, cache_ptr, token_head, rot_offset, embed_dim);
    }
  }
}

// Instantiate kernel variants for each dtype.
#define instantiate_rotary_embedding(type)                                     \
  template [[host_name("rotary_embedding_" #type)]] [[kernel]] void            \
  rotary_embedding_kernel<type>(                                               \
      const device int64_t *positions [[buffer(0)]],                           \
      device type *query [[buffer(1)]],                                        \
      device type *key [[buffer(2)]],                                          \
      const device type *cos_sin_cache [[buffer(3)]],                          \
      const device int &rot_dim [[buffer(4)]],                                 \
      const device int64_t &query_stride [[buffer(5)]],                        \
      const device int64_t &key_stride [[buffer(6)]],                          \
      const device int &head_size [[buffer(7)]],                               \
      const device int &num_heads [[buffer(8)]],                               \
      const device int &num_kv_heads [[buffer(9)]],                            \
      const device int &has_key [[buffer(10)]],                                \
      uint token_idx [[threadgroup_position_in_grid]],                         \
      uint tid [[thread_position_in_threadgroup]],                             \
      uint threads_per_tg [[threads_per_threadgroup]]);

instantiate_rotary_embedding(float);
instantiate_rotary_embedding(half);
instantiate_rotary_embedding(bfloat16_t);
|
rotary-embedding-metal/rotary_embedding.mm
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <ATen/mps/MPSDevice.h>
|
| 2 |
+
#include <ATen/mps/MPSStream.h>
|
| 3 |
+
#include <torch/torch.h>
|
| 4 |
+
|
| 5 |
+
#import <Foundation/Foundation.h>
|
| 6 |
+
#import <Metal/Metal.h>
|
| 7 |
+
#include <dlfcn.h>
|
| 8 |
+
#include <string>
|
| 9 |
+
|
| 10 |
+
// Returns the MTLBuffer backing an MPS tensor's storage.
// NOTE(review): this bit-casts the raw storage pointer straight to an
// id<MTLBuffer>, which relies on PyTorch's MPS allocator handing out
// MTLBuffer-backed storage -- confirm against the ATen MPS allocator
// before upgrading PyTorch versions.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}
|
| 13 |
+
|
| 14 |
+
// Returns the directory containing the shared object this function was
// compiled into (used to locate the metallib shipped next to the
// extension binary). Falls back to "." when the path cannot be resolved.
static std::string getModuleDirectory() {
  Dl_info info;
  if (dladdr(reinterpret_cast<void *>(&getModuleDirectory), &info) == 0) {
    return ".";
  }
  const std::string full_path{info.dli_fname};
  const size_t slash = full_path.find_last_of('/');
  if (slash == std::string::npos) {
    return ".";
  }
  return full_path.substr(0, slash);
}
|
| 25 |
+
|
| 26 |
+
// Applies rotary positional embedding in place to `query` and (optionally)
// `key`, matching vLLM's rotary_embedding(positions, query, key, head_size,
// cos_sin_cache, is_neox) op signature.
//
//   positions:     integer tensor, [num_tokens] or [batch, seq_len]
//   query:         [num_tokens, num_heads * head_size] or
//                  [num_tokens, num_heads, head_size]; rotated in place
//   key:           optional; same layouts with num_kv_heads; rotated in place
//   head_size:     per-head embedding dimension
//   cos_sin_cache: [max_position, rot_dim]; cos in the first half of the
//                  last dim, sin in the second half
//   is_neox:       true = NeoX rotation (Llama/Mistral), false = GPT-J
//
// NOTE(review): the kernel addresses elements as
// token * stride(0) + head * head_size, i.e. it assumes the head and
// head_size dims are densely packed -- confirm callers never pass
// transposed/strided views of those inner dims.
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
                      std::optional<torch::Tensor> key, int64_t head_size,
                      torch::Tensor &cos_sin_cache, bool is_neox) {
  TORCH_CHECK(query.device().is_mps(), "query must be on MPS device");
  TORCH_CHECK(positions.device().is_mps(), "positions must be on MPS device");
  TORCH_CHECK(cos_sin_cache.device().is_mps(),
              "cos_sin_cache must be on MPS device");
  // The kernel loads cache/key elements as the query's scalar type; a dtype
  // mismatch would silently reinterpret memory, so reject it up front.
  TORCH_CHECK(cos_sin_cache.scalar_type() == query.scalar_type(),
              "cos_sin_cache dtype must match query dtype");
  if (key.has_value()) {
    TORCH_CHECK(key->device().is_mps(), "key must be on MPS device");
    TORCH_CHECK(key->scalar_type() == query.scalar_type(),
                "key dtype must match query dtype");
  }

  const int64_t num_tokens = positions.numel();
  if (num_tokens == 0) {
    return;  // Empty batch: dispatching a zero-sized Metal grid is invalid.
  }

  // Flatten positions to 1D so the kernel indexes tokens linearly.
  torch::Tensor positions_flat = positions.reshape({-1});

  // Stride (in elements) between consecutive tokens. For the standard
  // [num_tokens, num_heads(, head_size)] layouts this equals
  // num_heads * head_size.
  int64_t query_stride = query.stride(0);
  int64_t key_stride = key.has_value() ? key->stride(0) : 0;

  // Derive head counts from element counts; works for both the flat
  // [num_tokens, H * head_size] and split [num_tokens, H, head_size] layouts.
  const int num_heads =
      static_cast<int>(query.numel() / (num_tokens * head_size));
  const int num_kv_heads =
      key.has_value()
          ? static_cast<int>(key->numel() / (num_tokens * head_size))
          : 0;

  const int rot_dim = cos_sin_cache.size(-1);
  const int embed_dim = rot_dim / 2;  // Half cos values, half sin values.
  const int has_key = key.has_value() ? 1 : 0;

  // One (head, rot_offset) pair per thread; bail out if there is no work
  // (a zero threadgroup width is likewise invalid to dispatch).
  const uint32_t work_items =
      static_cast<uint32_t>(std::max(num_heads, num_kv_heads)) *
      static_cast<uint32_t>(embed_dim);
  if (work_items == 0) {
    return;
  }

  @autoreleasepool {
    at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream, "Failed to get current MPS stream");

    id<MTLDevice> device = stream->device();
    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
    TORCH_CHECK(cmdBuf, "Failed to get command buffer");

    // Load the metallib that the build places next to this extension binary.
    std::string moduleDir = getModuleDirectory();
    std::string metallibPath = moduleDir + "/" + METALLIB_PATH;

    NSString *metallibPathStr =
        [NSString stringWithUTF8String:metallibPath.c_str()];
    NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
    NSError *error = nil;
    id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
    TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath,
                error ? [NSString stringWithFormat:@": %@",
                                                   error.localizedDescription]
                            .UTF8String
                      : "");

    // Select the kernel variant matching the tensor dtype.
    NSString *kernName = nil;
    switch (query.scalar_type()) {
    case torch::kFloat:
      kernName = @"rotary_embedding_float";
      break;
    case torch::kHalf:
      kernName = @"rotary_embedding_half";
      break;
    case torch::kBFloat16:
      kernName = @"rotary_embedding_bfloat16_t";
      break;
    default:
      TORCH_CHECK(false, "Unsupported dtype for rotary_embedding: ",
                  query.scalar_type());
    }

    // IS_NEOX is a function constant, so each rotation style compiles to a
    // branch-free specialization.
    MTLFunctionConstantValues *constants =
        [[MTLFunctionConstantValues alloc] init];
    [constants setConstantValue:&is_neox type:MTLDataTypeBool atIndex:0];

    id<MTLFunction> fn = [lib newFunctionWithName:kernName
                                   constantValues:constants
                                            error:&error];
    TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String,
                error ? [NSString stringWithFormat:@": %@",
                                                   error.localizedDescription]
                            .UTF8String
                      : "");

    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:fn error:&error];
    TORCH_CHECK(pso, "Failed to create pipeline state",
                error ? [NSString stringWithFormat:@": %@",
                                                   error.localizedDescription]
                            .UTF8String
                      : "");

    dispatch_queue_t q = stream->queue();
    dispatch_sync(q, ^{
      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(enc, "Failed to create compute encoder");

      [enc setComputePipelineState:pso];

      // Buffer 0: positions (flattened)
      [enc setBuffer:getMTLBufferStorage(positions_flat)
              offset:positions_flat.storage_offset() *
                     positions_flat.element_size()
             atIndex:0];

      // Buffer 1: query
      [enc setBuffer:getMTLBufferStorage(query)
              offset:query.storage_offset() * query.element_size()
             atIndex:1];

      // Buffer 2: key (or query as dummy if no key)
      if (key.has_value()) {
        [enc setBuffer:getMTLBufferStorage(*key)
                offset:key->storage_offset() * key->element_size()
               atIndex:2];
      } else {
        // Pass query buffer as dummy; has_key=0 ensures it's never accessed.
        [enc setBuffer:getMTLBufferStorage(query)
                offset:query.storage_offset() * query.element_size()
               atIndex:2];
      }

      // Buffer 3: cos_sin_cache
      [enc setBuffer:getMTLBufferStorage(cos_sin_cache)
              offset:cos_sin_cache.storage_offset() *
                     cos_sin_cache.element_size()
             atIndex:3];

      // Scalar parameters via setBytes.
      const int32_t rot_dim_i32 = static_cast<int32_t>(rot_dim);
      [enc setBytes:&rot_dim_i32 length:sizeof(int32_t) atIndex:4];

      [enc setBytes:&query_stride length:sizeof(int64_t) atIndex:5];
      [enc setBytes:&key_stride length:sizeof(int64_t) atIndex:6];

      const int32_t head_size_i32 = static_cast<int32_t>(head_size);
      [enc setBytes:&head_size_i32 length:sizeof(int32_t) atIndex:7];

      const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
      [enc setBytes:&num_heads_i32 length:sizeof(int32_t) atIndex:8];

      const int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
      [enc setBytes:&num_kv_heads_i32 length:sizeof(int32_t) atIndex:9];

      const int32_t has_key_i32 = static_cast<int32_t>(has_key);
      [enc setBytes:&has_key_i32 length:sizeof(int32_t) atIndex:10];

      // Dispatch: one threadgroup per token. Clamp the threadgroup width to
      // both a practical cap and what the pipeline actually supports.
      const uint32_t threads_per_tg = std::min<uint32_t>(
          std::min<uint32_t>(
              512, static_cast<uint32_t>(pso.maxTotalThreadsPerThreadgroup)),
          work_items);
      MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
      MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);

      [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
      [enc endEncoding];
    });

    stream->synchronize(at::mps::SyncType::COMMIT);
  }
}
|
rotary-embedding-metal/utils.metal
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// bfloat16 support shim for Metal.
//
// When the toolchain has native bfloat (__HAVE_BFLOAT__), bfloat16_t is an
// alias for the builtin type. Otherwise bfloat16 is emulated in software as
// a uint16_t bit container with float conversions (MLX-style _MLX_BFloat16).
#include <metal_stdlib>
using namespace metal;

#if defined(__HAVE_BFLOAT__)

typedef bfloat bfloat16_t;

#else

// Converts a float to bfloat16 bits (the top 16 bits of the IEEE-754
// binary32 representation) with round-to-nearest-even. NaN inputs map to a
// canonical quiet-NaN bit pattern (0x7FC0).
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
  if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
      _fp_encoding_traits<float>::inf_mask) {
    return uint16_t(as_type<uint32_t>(0x7FC0));
  }
  uint32_t float_bits = as_type<uint32_t>(x);
  // Round-to-nearest-even: add 0x7FFF plus the current LSB of the kept part.
  float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
  return float_bits >> 16;
}

// Widens bfloat16 bits back to float by placing them in the top 16 bits;
// this conversion is exact.
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
  return as_type<float>((uint32_t)x << 16);
}

struct _MLX_BFloat16;

// Conversion traits: allow any float-convertible type except the bfloat
// wrapper itself (avoids ambiguous self-conversions).
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_bfloat =
    !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;

// Software bfloat16: a uint16_t bit pattern with implicit conversions to and
// from arithmetic types, routed through float. Constructors and conversion
// operators are duplicated per Metal address space (thread / threadgroup /
// device / constant) because Metal qualifies member functions by the address
// space of *this.
struct _MLX_BFloat16 {
  uint16_t bits_;
  _MLX_BFloat16() thread = default;
  _MLX_BFloat16() threadgroup = default;
  _MLX_BFloat16() device = default;
  _MLX_BFloat16() constant = default;

  // Tag type so a raw-bits constructor can coexist with the numeric ones.
  struct bits_to_bfloat_struct {};
  static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
    return bits_to_bfloat_struct();
  }
  constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
      : bits_(bits) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) thread
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) device
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_to_bfloat<T>>::type>
  constexpr METAL_FUNC _MLX_BFloat16(T x) constant
      : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const thread {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const threadgroup {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() const device {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }

  template <typename T,
            typename = typename enable_if<can_convert_from_bfloat<T>>::type>
  constexpr METAL_FUNC operator T() constant {
    return static_cast<T>(bfloat_bits_to_float(bits_));
  }
};

// Unary negation, computed in float.
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
  return -static_cast<float>(x);
}

// Binary arithmetic operators: all math is performed in `ctype` (float),
// then narrowed to the declared result type `otype`.
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype)   \
  constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) {             \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);            \
  }

#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype)        \
  constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) {     \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);            \
  }                                                                           \
  constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) {     \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);            \
  }

// Mixed-type overloads: bfloat op float/half yields float; bfloat op integer
// yields bfloat.
#define bfloat_binop(_op_, _operator_)                                        \
  bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16,           \
                    _MLX_BFloat16, float);                                    \
  bfloat_binop_helper(_op_, _operator_, float, float, float);                 \
  bfloat_binop_helper(_op_, _operator_, float, half, float);                  \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float);       \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float);      \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float);       \
  bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);

bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);

#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop

typedef struct _MLX_BFloat16 bfloat16_t;

#endif
|
tests/test_rotary_embedding.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Metal rotary embedding kernel.
|
| 2 |
+
|
| 3 |
+
Validates correctness against a pure-PyTorch reference implementation
|
| 4 |
+
for both NeoX (Llama/Mistral) and GPT-J rotation styles.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
import rotary_embedding as ops
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _is_mps_available() -> bool:
|
| 14 |
+
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
if _is_mps_available():
|
| 18 |
+
DEVICES = ["mps"]
|
| 19 |
+
else:
|
| 20 |
+
DEVICES = [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
|
| 21 |
+
|
| 22 |
+
DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
| 23 |
+
HEAD_SIZES = [64, 128, 256]
|
| 24 |
+
NUM_HEADS = [8, 32]
|
| 25 |
+
NUM_KV_HEADS = [1, 8] # GQA and MHA
|
| 26 |
+
IS_NEOX = [True, False]
|
| 27 |
+
NUM_TOKENS = [1, 7, 32]
|
| 28 |
+
MAX_POSITION = 8192
|
| 29 |
+
ROTARY_DIM_FRACTIONS = [1.0] # Full rotation; 0.5 for partial
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _build_cos_sin_cache(
|
| 33 |
+
max_position: int,
|
| 34 |
+
rotary_dim: int,
|
| 35 |
+
dtype: torch.dtype,
|
| 36 |
+
device: str,
|
| 37 |
+
base: float = 10000.0,
|
| 38 |
+
) -> torch.Tensor:
|
| 39 |
+
"""Build a cos/sin cache matching vLLM's convention."""
|
| 40 |
+
inv_freq = 1.0 / (
|
| 41 |
+
base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
|
| 42 |
+
)
|
| 43 |
+
t = torch.arange(max_position, dtype=torch.float32)
|
| 44 |
+
freqs = torch.outer(t, inv_freq) # [max_position, rotary_dim/2]
|
| 45 |
+
cos_vals = freqs.cos()
|
| 46 |
+
sin_vals = freqs.sin()
|
| 47 |
+
cache = torch.cat([cos_vals, sin_vals], dim=-1) # [max_position, rotary_dim]
|
| 48 |
+
return cache.to(dtype=dtype, device=device)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _ref_rotary_embedding(
|
| 52 |
+
positions: torch.Tensor,
|
| 53 |
+
query: torch.Tensor,
|
| 54 |
+
key: torch.Tensor | None,
|
| 55 |
+
head_size: int,
|
| 56 |
+
cos_sin_cache: torch.Tensor,
|
| 57 |
+
is_neox: bool,
|
| 58 |
+
) -> None:
|
| 59 |
+
"""Pure-PyTorch reference implementation of rotary embedding."""
|
| 60 |
+
rot_dim = cos_sin_cache.shape[-1]
|
| 61 |
+
embed_dim = rot_dim // 2
|
| 62 |
+
|
| 63 |
+
num_tokens = positions.numel()
|
| 64 |
+
positions_flat = positions.reshape(-1)
|
| 65 |
+
|
| 66 |
+
for t in range(num_tokens):
|
| 67 |
+
pos = positions_flat[t].item()
|
| 68 |
+
cos_vals = cos_sin_cache[pos, :embed_dim].float()
|
| 69 |
+
sin_vals = cos_sin_cache[pos, embed_dim:].float()
|
| 70 |
+
|
| 71 |
+
# Apply to query heads.
|
| 72 |
+
num_heads = query.shape[-2]
|
| 73 |
+
for h in range(num_heads):
|
| 74 |
+
for d in range(embed_dim):
|
| 75 |
+
if is_neox:
|
| 76 |
+
x_idx, y_idx = d, embed_dim + d
|
| 77 |
+
else:
|
| 78 |
+
x_idx, y_idx = 2 * d, 2 * d + 1
|
| 79 |
+
|
| 80 |
+
x = query[t, h, x_idx].float()
|
| 81 |
+
y = query[t, h, y_idx].float()
|
| 82 |
+
query[t, h, x_idx] = (x * cos_vals[d] - y * sin_vals[d]).to(
|
| 83 |
+
query.dtype
|
| 84 |
+
)
|
| 85 |
+
query[t, h, y_idx] = (y * cos_vals[d] + x * sin_vals[d]).to(
|
| 86 |
+
query.dtype
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Apply to key heads.
|
| 90 |
+
if key is not None:
|
| 91 |
+
num_kv_heads = key.shape[-2]
|
| 92 |
+
for h in range(num_kv_heads):
|
| 93 |
+
for d in range(embed_dim):
|
| 94 |
+
if is_neox:
|
| 95 |
+
x_idx, y_idx = d, embed_dim + d
|
| 96 |
+
else:
|
| 97 |
+
x_idx, y_idx = 2 * d, 2 * d + 1
|
| 98 |
+
|
| 99 |
+
x = key[t, h, x_idx].float()
|
| 100 |
+
y = key[t, h, y_idx].float()
|
| 101 |
+
key[t, h, x_idx] = (x * cos_vals[d] - y * sin_vals[d]).to(
|
| 102 |
+
key.dtype
|
| 103 |
+
)
|
| 104 |
+
key[t, h, y_idx] = (y * cos_vals[d] + x * sin_vals[d]).to(
|
| 105 |
+
key.dtype
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("is_neox", IS_NEOX)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@torch.inference_mode()
def test_rotary_embedding(
    device: str,
    dtype: torch.dtype,
    head_size: int,
    num_heads: int,
    num_kv_heads: int,
    is_neox: bool,
    num_tokens: int,
) -> None:
    """Compare the Metal kernel against the pure-PyTorch reference.

    Both implementations rotate query/key in place; the kernel runs on
    `device`, the reference on CPU, and the results are compared with
    dtype-dependent tolerances.
    """
    # Skip invalid GQA configs.
    if num_heads % num_kv_heads != 0:
        pytest.skip("num_heads must be divisible by num_kv_heads")

    rotary_dim = head_size  # Full rotation
    cos_sin_cache = _build_cos_sin_cache(
        MAX_POSITION, rotary_dim, dtype, device
    )

    # Random positions (arbitrary, non-contiguous to test flexibility).
    positions = torch.randint(0, MAX_POSITION, (num_tokens,), device=device)

    # Random query and key tensors.
    query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=dtype, device=device)

    # Clone for reference (the reference mutates its inputs in place).
    query_ref = query.clone()
    key_ref = key.clone()

    # Run kernel (in-place).
    ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)

    # Run reference on CPU. BUG FIX: on a non-CPU device, `.cpu()` returns a
    # *copy*, so the reference previously mutated a throwaway tensor while the
    # comparison below read the untouched clone. Keep explicit handles to the
    # tensors the reference actually writes, and compare against those.
    positions_cpu = positions.cpu() if device != "cpu" else positions
    query_cpu = query_ref.cpu() if device != "cpu" else query_ref
    key_cpu = key_ref.cpu() if device != "cpu" else key_ref
    cache_cpu = cos_sin_cache.cpu() if device != "cpu" else cos_sin_cache
    _ref_rotary_embedding(
        positions_cpu,
        query_cpu,
        key_cpu,
        head_size,
        cache_cpu,
        is_neox,
    )

    # Compare. Use relaxed tolerances for fp16/bf16.
    if dtype == torch.float32:
        atol, rtol = 1e-5, 1e-5
    elif dtype == torch.float16:
        atol, rtol = 1e-3, 1e-3
    else:  # bfloat16
        atol, rtol = 2e-2, 2e-2

    query_ref_dev = query_cpu.to(device=device)
    key_ref_dev = key_cpu.to(device=device)

    torch.testing.assert_close(query, query_ref_dev, atol=atol, rtol=rtol)
    torch.testing.assert_close(key, key_ref_dev, atol=atol, rtol=rtol)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("is_neox", [True])
@torch.inference_mode()
def test_rotary_embedding_no_key(
    device: str,
    dtype: torch.dtype,
    is_neox: bool,
) -> None:
    """Test that passing key=None works correctly."""
    head_size = 128
    num_heads = 8
    num_tokens = 4
    rotary_dim = head_size
    cos_sin_cache = _build_cos_sin_cache(
        MAX_POSITION, rotary_dim, dtype, device
    )
    positions = torch.randint(0, MAX_POSITION, (num_tokens,), device=device)
    query = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)

    # Clone for reference (the reference mutates its input in place).
    query_ref = query.clone()

    # Run kernel with key=None.
    ops.rotary_embedding(positions, query, None, head_size, cos_sin_cache, is_neox)

    # Run reference with key=None on CPU. BUG FIX: on a non-CPU device,
    # `.cpu()` returns a *copy*, so the reference previously mutated a
    # temporary while the assertion compared against the untouched clone.
    # Keep a handle to the tensor the reference actually writes.
    query_cpu = query_ref.cpu() if device != "cpu" else query_ref
    _ref_rotary_embedding(
        positions.cpu() if device != "cpu" else positions,
        query_cpu,
        None,
        head_size,
        cos_sin_cache.cpu() if device != "cpu" else cos_sin_cache,
        is_neox,
    )

    query_ref_dev = query_cpu.to(device=device)
    torch.testing.assert_close(query, query_ref_dev, atol=1e-5, rtol=1e-5)
|
torch-ext/rotary_embedding/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public interface for the rotary_embedding Metal extension."""

from ._custom_ops import rotary_embedding
from ._ops import ops

# Re-export the raw op namespace alongside the typed Python wrapper.
__all__ = ["ops", "rotary_embedding"]
|
torch-ext/rotary_embedding/_custom_ops.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional

import torch

from ._ops import ops


def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply rotary positional embedding to ``query`` (and ``key``) in place.

    Thin wrapper matching the vLLM custom-op signature; it forwards all
    arguments unchanged to the registered ``rotary_embedding`` op. ``key``
    may be ``None``, in which case only ``query`` is rotated.
    """
    ops.rotary_embedding(
        positions, query, key, head_size, cos_sin_cache, is_neox
    )
|
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Torch library registration for the rotary_embedding extension.

#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Registers the op schema and, when the Metal backend is compiled in,
// binds the MPS implementation declared in torch_binding.h.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema annotations: `Tensor!` marks query/key as mutated in place;
  // `Tensor!?` additionally makes key optional (callers may pass None).
  ops.def("rotary_embedding(Tensor positions, Tensor! query,"
          " Tensor!? key, int head_size,"
          " Tensor cos_sin_cache, bool is_neox) -> ()");
#if defined(METAL_KERNEL)
  // Only the MPS (Metal) dispatch key is implemented by this extension.
  ops.impl("rotary_embedding", torch::kMPS, rotary_embedding);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once

#include <torch/torch.h>

// In-place rotary positional embedding (vLLM-compatible signature).
// Rotates `query` — and `key`, when present — using precomputed values
// looked up in `cos_sin_cache` at each token's position. `is_neox`
// selects NeoX-style vs GPT-J-style element pairing.
// NOTE(review): implementation lives in the Metal backend
// (rotary-embedding-metal/rotary_embedding.mm) — not visible here.
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
                      std::optional<torch::Tensor> key, int64_t head_size,
                      torch::Tensor &cos_sin_cache, bool is_neox);
|