Eric Buehler committed
Commit 6c879ab · 1 Parent(s): 3176d7a

Doesn't crash?

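The change applied across the three .mm files below is the same: the removed torch::mps::get_command_buffer() / torch::mps::get_dispatch_queue() / torch::mps::commit() calls are replaced with PyTorch's at::mps::MPSStream interface. A minimal sketch of that encoding pattern, assuming only the API calls visible in the hunks below (the wrapper function name and the elided encoder setup are placeholders, not part of the commit):

#include <torch/torch.h>
#include <ATen/mps/MPSStream.h>
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// Illustrative only: the stream-based command-buffer lifecycle this commit switches to.
static void encodeOnCurrentMPSStream() {
  @autoreleasepool {
    // Fetch the current MPS stream instead of the old torch::mps helpers.
    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    TORCH_CHECK(stream, "Failed to get current MPS stream");

    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
    TORCH_CHECK(cmdBuf, "Failed to get command buffer");

    // Encode on the stream's serial dispatch queue.
    dispatch_queue_t q = stream->queue();
    dispatch_sync(q, ^{
      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(enc, "Failed to create compute encoder");
      // ... bind pipeline state / buffers and dispatch threads (elided) ...
      [enc endEncoding];
      // Commit through the stream instead of torch::mps::commit().
      stream->synchronize(at::mps::SyncType::COMMIT);
    });
  }
}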
paged-attention-metal/cache.mm CHANGED
@@ -1,4 +1,6 @@
 #include <torch/torch.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSDevice.h>
 
 #import <Foundation/Foundation.h>
 #import <Metal/Metal.h>
@@ -34,10 +36,13 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   const int64_t num_blocks = block_mapping.size(0);
 
   @autoreleasepool {
-    id<MTLCommandBuffer> commandBuffer = torch::mps::get_command_buffer();
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
     TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
 
-    dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue();
+    dispatch_queue_t serialQueue = stream->queue();
 
     dispatch_sync(serialQueue, ^{
       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
@@ -60,7 +65,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
       }
 
       [blitEncoder endEncoding];
-      torch::mps::commit();
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -145,10 +150,13 @@ void copy_blocks(const std::vector<torch::Tensor>& key_caches,
     TORCH_CHECK(pso, err.localizedDescription.UTF8String);
 
     // --- Encode dispatch ----------------------------------------------
-    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get command buffer");
 
-    dispatch_queue_t q = torch::mps::get_dispatch_queue();
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
       TORCH_CHECK(enc, "Failed to create compute encoder");
@@ -171,7 +179,7 @@ void copy_blocks(const std::vector<torch::Tensor>& key_caches,
       [enc dispatchThreads:grid threadsPerThreadgroup:tg];
       [enc endEncoding];
 
-      torch::mps::commit();
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -248,10 +256,13 @@ void reshape_and_cache(
     // -----------------------------------------------------------------
     // Encode dispatch
     // -----------------------------------------------------------------
-    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get command buffer");
 
-    dispatch_queue_t q = torch::mps::get_dispatch_queue();
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
      id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
      TORCH_CHECK(enc, "Failed to create compute encoder");
@@ -298,7 +309,7 @@ void reshape_and_cache(
      [enc dispatchThreads:grid threadsPerThreadgroup:tg];
      [enc endEncoding];
 
-     torch::mps::commit();
+     stream->synchronize(at::mps::SyncType::COMMIT);
    });
  }
}
paged-attention-metal/paged_attention.mm CHANGED
@@ -1,4 +1,6 @@
 #include <torch/torch.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSDevice.h>
 
 #import <Foundation/Foundation.h>
 #import <Metal/Metal.h>
@@ -43,7 +45,7 @@ static std::string getKernelName(const std::string& base_name, torch::ScalarType
                  "_nt" + std::to_string(num_threads) +
                  "_nsl" + std::to_string(num_simd_lanes);
 
-  if (partition_size > 0) {
+  if (partition_size >= 0) {
     kernel_name += "_ps" + std::to_string(partition_size);
   }
 
@@ -60,7 +62,6 @@ static size_t calculateSharedMemorySize(int max_seq_len, int head_size, int num_
 
   // Output workspace for cross-warp reduction: head_size * sizeof(float)
   size_t output_size = head_size * sizeof(float);
-
   return std::max(logits_size + reduction_size, output_size);
 }
 
@@ -122,9 +123,9 @@ void paged_attention_v1(
   // Calculate shared memory requirements
   size_t shared_memory_size = calculateSharedMemorySize(max_seq_len, head_size, num_threads, num_simd_lanes);
 
-  // Get kernel name
+  // Get kernel name - v1 kernels have partition_size=0 in their name
   std::string kernel_name = getKernelName("paged_attention", query.scalar_type(),
-                                          head_size, block_size, num_threads, num_simd_lanes);
+                                          head_size, block_size, num_threads, num_simd_lanes, partition_size);
 
   @autoreleasepool {
     id<MTLDevice> device = MTLCreateSystemDefaultDevice();
@@ -156,10 +157,13 @@ void paged_attention_v1(
                 error ? error.localizedDescription.UTF8String : "unknown error");
 
     // Setup command buffer and encoder
-    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
 
-    dispatch_queue_t q = torch::mps::get_dispatch_queue();
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
       TORCH_CHECK(enc, "Failed to create compute command encoder");
@@ -249,7 +253,7 @@ void paged_attention_v1(
       [enc dispatchThreadgroups:grid threadsPerThreadgroup:threadgroup];
      [enc endEncoding];
 
-      torch::mps::commit();
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
   }
 }
@@ -310,8 +314,19 @@ void paged_attention_v2(
   // Get kernel names
   std::string kernel_name = getKernelName("paged_attention", query.scalar_type(),
                                           head_size, block_size, num_threads, num_simd_lanes, partition_size);
-  std::string reduce_kernel_name = getKernelName("paged_attention_v2_reduce", query.scalar_type(),
-                                                 head_size, 0, num_threads, num_simd_lanes, partition_size);
+  // Reduce kernel doesn't have block_size in its name
+  std::string reduce_kernel_name = "paged_attention_v2_reduce";
+  switch (query.scalar_type()) {
+    case torch::kFloat: reduce_kernel_name += "_float"; break;
+    case torch::kHalf: reduce_kernel_name += "_half"; break;
+    case torch::kBFloat16: reduce_kernel_name += "_bfloat16_t"; break;
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for paged attention: ", query.scalar_type());
+  }
+  reduce_kernel_name += "_hs" + std::to_string(head_size) +
+                        "_nt" + std::to_string(num_threads) +
+                        "_nsl" + std::to_string(num_simd_lanes) +
+                        "_ps" + std::to_string(partition_size);
 
   @autoreleasepool {
     id<MTLDevice> device = MTLCreateSystemDefaultDevice();
@@ -327,10 +342,13 @@ void paged_attention_v2(
                 error ? error.localizedDescription.UTF8String : "unknown error");
 
     // Setup command buffer and queue
-    id<MTLCommandBuffer> cmdBuf = torch::mps::get_command_buffer();
+    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
+    TORCH_CHECK(stream, "Failed to get current MPS stream");
+
+    id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
     TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
 
-    dispatch_queue_t q = torch::mps::get_dispatch_queue();
+    dispatch_queue_t q = stream->queue();
     dispatch_sync(q, ^{
       // ==================================================================
       // Phase 1: Main paged attention kernel with partitioning
@@ -508,7 +526,7 @@ void paged_attention_v2(
       [reduceEnc dispatchThreadgroups:reduceGrid threadsPerThreadgroup:reduceThreadgroup];
      [reduceEnc endEncoding];
 
-      torch::mps::commit();
+      stream->synchronize(at::mps::SyncType::COMMIT);
     });
  }
 }
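A worked example of the reduce-kernel name assembled by the added switch/append code above (the parameter values are hypothetical, for illustration only): for a half-precision query with head_size=128, num_threads=256, num_simd_lanes=32 and partition_size=512, reduce_kernel_name becomes "paged_attention_v2_reduce_half_hs128_nt256_nsl32_ps512".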
tests/kernels/test_attention.py CHANGED
@@ -228,7 +228,7 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     elif version in ("v2", "rocm"):
@@ -291,7 +291,7 @@ def test_paged_attention(
                 64,
                 0,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     else:
@@ -336,7 +336,7 @@ def test_paged_attention(
                 k_scale,
                 v_scale,
             ),
-            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
         )
 
     else:
torch-ext/torch_binding.cpp CHANGED
@@ -15,103 +15,103 @@
 // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-  // Attention ops
-  // Compute the attention between an input query and the cached
-  // keys/values using PagedAttention.
-  ops.def(
-      "paged_attention_v1("
-      " Tensor! out, Tensor query, Tensor key_cache,"
-      " Tensor value_cache, int num_kv_heads, float scale,"
-      " Tensor block_tables, Tensor seq_lens, int block_size,"
-      " int max_seq_len, Tensor? alibi_slopes,"
-      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
-      " int tp_rank, int blocksparse_local_blocks,"
-      " int blocksparse_vert_stride, int blocksparse_block_size,"
-      " int blocksparse_head_sliding_step) -> ()");
+  // Attention ops
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  ops.def(
+      "paged_attention_v1("
+      " Tensor! out, Tensor query, Tensor key_cache,"
+      " Tensor value_cache, int num_kv_heads, float scale,"
+      " Tensor block_tables, Tensor seq_lens, int block_size,"
+      " int max_seq_len, Tensor? alibi_slopes,"
+      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
+      " int tp_rank, int blocksparse_local_blocks,"
+      " int blocksparse_vert_stride, int blocksparse_block_size,"
+      " int blocksparse_head_sliding_step) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
 #elif defined(METAL_KERNEL)
   ops.impl("paged_attention_v1", torch::kMPS, paged_attention_v1);
 #endif
 
-  // PagedAttention V2.
-  ops.def(
-      "paged_attention_v2("
-      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
-      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
-      " Tensor value_cache, int num_kv_heads, float scale,"
-      " Tensor block_tables, Tensor seq_lens, int block_size,"
-      " int max_seq_len, Tensor? alibi_slopes,"
-      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
-      " int tp_rank, int blocksparse_local_blocks,"
-      " int blocksparse_vert_stride, int blocksparse_block_size,"
-      " int blocksparse_head_sliding_step) -> ()");
+  // PagedAttention V2.
+  ops.def(
+      "paged_attention_v2("
+      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
+      " Tensor value_cache, int num_kv_heads, float scale,"
+      " Tensor block_tables, Tensor seq_lens, int block_size,"
+      " int max_seq_len, Tensor? alibi_slopes,"
+      " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
+      " int tp_rank, int blocksparse_local_blocks,"
+      " int blocksparse_vert_stride, int blocksparse_block_size,"
+      " int blocksparse_head_sliding_step) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 #elif defined(METAL_KERNEL)
   ops.impl("paged_attention_v2", torch::kMPS, paged_attention_v2);
 #endif
 
-  // Swap in (out) the cache blocks from src to dst.
-  ops.def(
-      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
+  // Swap in (out) the cache blocks from src to dst.
+  ops.def(
+      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 #elif defined(METAL_KERNEL)
   ops.impl("swap_blocks", torch::kMPS, swap_blocks);
 #endif
 
-  // Copy the cache blocks from src to dst.
-  ops.def(
-      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
-      "Tensor block_mapping) -> ()");
+  // Copy the cache blocks from src to dst.
+  ops.def(
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 #elif defined(METAL_KERNEL)
   ops.impl("copy_blocks", torch::kMPS, copy_blocks);
 #endif
 
-  // Reshape the key and value tensors and cache them.
-  ops.def(
-      "reshape_and_cache(Tensor key, Tensor value,"
-      " Tensor! key_cache, Tensor! value_cache,"
-      " Tensor slot_mapping,"
-      " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+  // Reshape the key and value tensors and cache them.
+  ops.def(
+      "reshape_and_cache(Tensor key, Tensor value,"
+      " Tensor! key_cache, Tensor! value_cache,"
+      " Tensor slot_mapping,"
+      " str kv_cache_dtype,"
+      " Tensor k_scale, Tensor v_scale) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
 #elif defined(METAL_KERNEL)
   ops.impl("reshape_and_cache", torch::kMPS, reshape_and_cache);
 #endif
 
-  // Reshape the key and value tensors and cache them.
-  ops.def(
-      "reshape_and_cache_flash(Tensor key, Tensor value,"
-      " Tensor! key_cache,"
-      " Tensor! value_cache,"
-      " Tensor slot_mapping,"
-      " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+  // Reshape the key and value tensors and cache them.
+  ops.def(
+      "reshape_and_cache_flash(Tensor key, Tensor value,"
+      " Tensor! key_cache,"
+      " Tensor! value_cache,"
+      " Tensor slot_mapping,"
+      " str kv_cache_dtype,"
+      " Tensor k_scale, Tensor v_scale) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash);
 #elif defined(METAL_KERNEL)
   ops.impl("reshape_and_cache_flash", torch::kMPS, reshape_and_cache_flash);
 #endif
 
-  // Gets the specified device attribute.
-  ops.def("get_device_attribute(int attribute, int device_id) -> int");
-  ops.impl("get_device_attribute", &get_device_attribute);
+  // Gets the specified device attribute.
+  ops.def("get_device_attribute(int attribute, int device_id) -> int");
+  ops.impl("get_device_attribute", &get_device_attribute);
 
-  // Gets the maximum shared memory per block device attribute.
-  ops.def(
-      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
-  ops.impl("get_max_shared_memory_per_block_device_attribute",
-           &get_max_shared_memory_per_block_device_attribute);
+  // Gets the maximum shared memory per block device attribute.
+  ops.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
+  ops.impl("get_max_shared_memory_per_block_device_attribute",
+           &get_max_shared_memory_per_block_device_attribute);
 
-  // Convert the key and value cache to fp8 data type.
-  ops.def(
-      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
-      "str kv_cache_dtype) -> ()");
+  // Convert the key and value cache to fp8 data type.
+  ops.def(
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
 #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
   ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 #elif defined(METAL_KERNEL)