Eric Buehler committed
Commit · 6c879ab · Parent(s): 3176d7a

Doesnt crash?

Browse files
- paged-attention-metal/cache.mm: +20 -9
- paged-attention-metal/paged_attention.mm: +30 -12
- tests/kernels/test_attention.py: +3 -3
- torch-ext/torch_binding.cpp: +59 -59
paged-attention-metal/cache.mm
CHANGED

@@ -1,4 +1,6 @@
 #include <torch/torch.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSDevice.h>
 
 #import <Foundation/Foundation.h>
 #import <Metal/Metal.h>
@@ -34,10 +36,13 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   const int64_t num_blocks = block_mapping.size(0);
 
   @autoreleasepool {
-
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
     TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
 
-    dispatch_queue_t serialQueue =
+    dispatch_queue_t serialQueue = stream->queue();
 
     dispatch_sync(serialQueue, ^{
       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
@@ -60,7 +65,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
       }
 
       [blitEncoder endEncoding];
-
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -145,10 +150,13 @@ void copy_blocks(const std::vector<torch::Tensor>& key_caches,
     TORCH_CHECK(pso, err.localizedDescription.UTF8String);
 
     // --- Encode dispatch ----------------------------------------------
-
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get command buffer");
 
-    dispatch_queue_t q =
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
       TORCH_CHECK(enc, "Failed to create compute encoder");
@@ -171,7 +179,7 @@ void copy_blocks(const std::vector<torch::Tensor>& key_caches,
       [enc dispatchThreads:grid threadsPerThreadgroup:tg];
      [enc endEncoding];
 
-
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -248,10 +256,13 @@ void reshape_and_cache(
     // -----------------------------------------------------------------
     // Encode dispatch
     // -----------------------------------------------------------------
-
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get command buffer");
 
-    dispatch_queue_t q =
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
       TORCH_CHECK(enc, "Failed to create compute encoder");
@@ -298,7 +309,7 @@ void reshape_and_cache(
       [enc dispatchThreads:grid threadsPerThreadgroup:tg];
       [enc endEncoding];
 
-
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
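All three cache ops in this file now share one submission pattern: get PyTorch's
current MPS stream, encode onto that stream's command buffer from its serial
dispatch queue, and commit through the stream instead of committing the Metal
command buffer directly. A minimal consolidated sketch of that pattern, using
only the calls visible in the diff; the helper name run_on_mps_stream and the
encodeWork block are illustrative, not part of the commit:

    #include <torch/torch.h>
    #include <ATen/mps/MPSStream.h>

    #import <Metal/Metal.h>

    // Sketch of the shape shared by swap_blocks, copy_blocks, and
    // reshape_and_cache after this commit (hypothetical helper).
    static void run_on_mps_stream(void (^encodeWork)(id<MTLCommandBuffer>)) {
      @autoreleasepool {
        at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
        TORCH_CHECK(stream, "Failed to get current MPS stream");

        id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
        TORCH_CHECK(cmdBuf, "Failed to get command buffer");

        // Encode on the stream's serial queue so this work is ordered
        // against PyTorch's own MPS submissions.
        dispatch_sync(stream->queue(), ^{
          encodeWork(cmdBuf);
          // Commit via the stream, not [cmdBuf commit], so PyTorch keeps
          // ownership of the command buffer's lifecycle.
          stream->synchronize(at::mps::SyncType::COMMIT);
        });
      }
    }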
paged-attention-metal/paged_attention.mm
CHANGED

@@ -1,4 +1,6 @@
 #include <torch/torch.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSDevice.h>
 
 #import <Foundation/Foundation.h>
 #import <Metal/Metal.h>
@@ -43,7 +45,7 @@ static std::string getKernelName(const std::string& base_name, torch::ScalarType
                  "_nt" + std::to_string(num_threads) +
                  "_nsl" + std::to_string(num_simd_lanes);
 
-  if (partition_size > 0) {
+  if (partition_size >= 0) {
     kernel_name += "_ps" + std::to_string(partition_size);
   }
 
@@ -60,7 +62,6 @@ static size_t calculateSharedMemorySize(int max_seq_len, int head_size, int num_
 
   // Output workspace for cross-warp reduction: head_size * sizeof(float)
   size_t output_size = head_size * sizeof(float);
-
   return std::max(logits_size + reduction_size, output_size);
 }
 
@@ -122,9 +123,9 @@ void paged_attention_v1(
   // Calculate shared memory requirements
   size_t shared_memory_size = calculateSharedMemorySize(max_seq_len, head_size, num_threads, num_simd_lanes);
 
-  // Get kernel name
+  // Get kernel name - v1 kernels have partition_size=0 in their name
   std::string kernel_name = getKernelName("paged_attention", query.scalar_type(),
-                                          head_size, block_size, num_threads, num_simd_lanes);
+                                          head_size, block_size, num_threads, num_simd_lanes, partition_size);
 
   @autoreleasepool {
     id<MTLDevice> device = MTLCreateSystemDefaultDevice();
@@ -156,10 +157,13 @@ void paged_attention_v1(
                 error ? error.localizedDescription.UTF8String : "unknown error");
 
     // Setup command buffer and encoder
-
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
 
-    dispatch_queue_t q =
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
       TORCH_CHECK(enc, "Failed to create compute command encoder");
@@ -249,7 +253,7 @@ void paged_attention_v1(
       [enc dispatchThreadgroups:grid threadsPerThreadgroup:threadgroup];
       [enc endEncoding];
 
-
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -310,8 +314,19 @@ void paged_attention_v2(
   // Get kernel names
   std::string kernel_name = getKernelName("paged_attention", query.scalar_type(),
                                           head_size, block_size, num_threads, num_simd_lanes, partition_size);
-
-
+  // Reduce kernel doesn't have block_size in its name
+  std::string reduce_kernel_name = "paged_attention_v2_reduce";
+  switch (query.scalar_type()) {
+    case torch::kFloat: reduce_kernel_name += "_float"; break;
+    case torch::kHalf: reduce_kernel_name += "_half"; break;
+    case torch::kBFloat16: reduce_kernel_name += "_bfloat16_t"; break;
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for paged attention: ", query.scalar_type());
+  }
+  reduce_kernel_name += "_hs" + std::to_string(head_size) +
+                        "_nt" + std::to_string(num_threads) +
+                        "_nsl" + std::to_string(num_simd_lanes) +
+                        "_ps" + std::to_string(partition_size);
 
   @autoreleasepool {
     id<MTLDevice> device = MTLCreateSystemDefaultDevice();
@@ -327,10 +342,13 @@ void paged_attention_v2(
                 error ? error.localizedDescription.UTF8String : "unknown error");
 
     // Setup command buffer and queue
-
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
 
-    dispatch_queue_t q =
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       // ==================================================================
       // Phase 1: Main paged attention kernel with partitioning
@@ -508,7 +526,7 @@ void paged_attention_v2(
       [reduceEnc dispatchThreadgroups:reduceGrid threadsPerThreadgroup:reduceThreadgroup];
      [reduceEnc endEncoding];
 
-
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
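The ">" to ">=" change in getKernelName works together with the v1 change
above: paged_attention_v1 now passes its partition_size of 0 through, so v1
kernels resolve to names ending in _ps0 rather than omitting the suffix (the
old "> 0" test would have dropped it). A hedged sketch of the naming scheme
the diff implies, compilable standalone; the _bs suffix of the main kernel and
the placement of suffixes outside the visible lines are assumptions:

    #include <iostream>
    #include <string>

    // Illustration only: mirrors the name composition visible in the diff.
    // The reduce kernel deliberately omits block_size ("_bs").
    static std::string mainKernelName(const std::string& dtype, int hs, int bs,
                                      int nt, int nsl, int ps) {
      return "paged_attention_" + dtype +
             "_hs" + std::to_string(hs) + "_bs" + std::to_string(bs) +
             "_nt" + std::to_string(nt) + "_nsl" + std::to_string(nsl) +
             "_ps" + std::to_string(ps);
    }

    static std::string reduceKernelName(const std::string& dtype, int hs,
                                        int nt, int nsl, int ps) {
      return "paged_attention_v2_reduce_" + dtype +
             "_hs" + std::to_string(hs) +
             "_nt" + std::to_string(nt) + "_nsl" + std::to_string(nsl) +
             "_ps" + std::to_string(ps);
    }

    int main() {
      // v1 path: partition_size == 0, hence the new ">= 0" check.
      std::cout << mainKernelName("float", 128, 16, 256, 32, 0) << "\n";
      // v2 path with 512-token partitions, plus its reduce kernel:
      std::cout << mainKernelName("float", 128, 16, 256, 32, 512) << "\n";
      std::cout << reduceKernelName("float", 128, 256, 32, 512) << "\n";
    }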
tests/kernels/test_attention.py
CHANGED

@@ -228,7 +228,7 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     elif version in ("v2", "rocm"):
@@ -291,7 +291,7 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     else:
@@ -336,7 +336,7 @@ def test_paged_attention(
                 k_scale,
                 v_scale,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     else:
torch-ext/torch_binding.cpp
CHANGED

@@ -15,103 +15,103 @@
 // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-
-
-
-
-
-
-
-
-
-
-
-
-
+  // Attention ops
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  ops.def(
+      "paged_attention_v1("
+      "    Tensor! out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
 #elif defined(METAL_KERNEL)
   ops.impl("paged_attention_v1", torch::kMPS, paged_attention_v1);
 #endif
 
-
-
-
-
-
-
-
-
-
-
-
-
+  // PagedAttention V2.
+  ops.def(
+      "paged_attention_v2("
+      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
+      "    Tensor value_cache, int num_kv_heads, float scale,"
+      "    Tensor block_tables, Tensor seq_lens, int block_size,"
+      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
+      "    int tp_rank, int blocksparse_local_blocks,"
+      "    int blocksparse_vert_stride, int blocksparse_block_size,"
+      "    int blocksparse_head_sliding_step) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 #elif defined(METAL_KERNEL)
   ops.impl("paged_attention_v2", torch::kMPS, paged_attention_v2);
 #endif
 
-
-
-
+  // Swap in (out) the cache blocks from src to dst.
+  ops.def(
+      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 #elif defined(METAL_KERNEL)
   ops.impl("swap_blocks", torch::kMPS, swap_blocks);
 #endif
 
-
-
-
-
+  // Copy the cache blocks from src to dst.
+  ops.def(
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 #elif defined(METAL_KERNEL)
   ops.impl("copy_blocks", torch::kMPS, copy_blocks);
 #endif
 
-
-
-
-
-
-
-
+  // Reshape the key and value tensors and cache them.
+  ops.def(
+      "reshape_and_cache(Tensor key, Tensor value,"
+      "                  Tensor! key_cache, Tensor! value_cache,"
+      "                  Tensor slot_mapping,"
+      "                  str kv_cache_dtype,"
+      "                  Tensor k_scale, Tensor v_scale) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
 #elif defined(METAL_KERNEL)
   ops.impl("reshape_and_cache", torch::kMPS, reshape_and_cache);
 #endif
 
-
-
-
-
-
-
-
-
+  // Reshape the key and value tensors and cache them.
+  ops.def(
+      "reshape_and_cache_flash(Tensor key, Tensor value,"
+      "                        Tensor! key_cache,"
+      "                        Tensor! value_cache,"
+      "                        Tensor slot_mapping,"
+      "                        str kv_cache_dtype,"
+      "                        Tensor k_scale, Tensor v_scale) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash);
 #elif defined(METAL_KERNEL)
   ops.impl("reshape_and_cache_flash", torch::kMPS, reshape_and_cache_flash);
 #endif
 
-
-
-
+  // Gets the specified device attribute.
+  ops.def("get_device_attribute(int attribute, int device_id) -> int");
+  ops.impl("get_device_attribute", &get_device_attribute);
 
-
-
-
-
-
+  // Gets the maximum shared memory per block device attribute.
+  ops.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
+  ops.impl("get_max_shared_memory_per_block_device_attribute",
+           &get_max_shared_memory_per_block_device_attribute);
 
-
-
-
-
+  // Convert the key and value cache to fp8 data type.
+  ops.def(
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 #elif defined(METAL_KERNEL)