| #include <torch/library.h> |
|
|
| #include "registration.h" |
|
|
| #include "torch_binding.h" |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| |
| |
| |
| ops.def( |
| "paged_attention_v1(" |
| " Tensor! out, Tensor query, Tensor key_cache," |
| " Tensor value_cache, int num_kv_heads, float scale," |
| " Tensor block_tables, Tensor seq_lens, int block_size," |
| " int max_seq_len, Tensor? alibi_slopes," |
| " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," |
| " int tp_rank, int blocksparse_local_blocks," |
| " int blocksparse_vert_stride, int blocksparse_block_size," |
| " int blocksparse_head_sliding_step) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); |
| #elif defined(METAL_KERNEL) |
| ops.impl("paged_attention_v1", torch::kMPS, paged_attention_v1); |
| #endif |
|
|
| |
| ops.def( |
| "paged_attention_v2(" |
| " Tensor! out, Tensor! exp_sums, Tensor! max_logits," |
| " Tensor! tmp_out, Tensor query, Tensor key_cache," |
| " Tensor value_cache, int num_kv_heads, float scale," |
| " Tensor block_tables, Tensor seq_lens, int block_size," |
| " int max_seq_len, Tensor? alibi_slopes," |
| " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," |
| " int tp_rank, int blocksparse_local_blocks," |
| " int blocksparse_vert_stride, int blocksparse_block_size," |
| " int blocksparse_head_sliding_step) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); |
| #elif defined(METAL_KERNEL) |
| ops.impl("paged_attention_v2", torch::kMPS, paged_attention_v2); |
| #endif |
|
|
| |
| ops.def( |
| "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); |
| #elif defined(METAL_KERNEL) |
| ops.impl("swap_blocks", torch::kMPS, swap_blocks); |
| #endif |
|
|
| |
| ops.def( |
| "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " |
| "Tensor block_mapping) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("copy_blocks", torch::kCUDA, ©_blocks); |
| #elif defined(METAL_KERNEL) |
| ops.impl("copy_blocks", torch::kMPS, copy_blocks); |
| #endif |
|
|
| |
| ops.def( |
| "reshape_and_cache(Tensor key, Tensor value," |
| " Tensor! key_cache, Tensor! value_cache," |
| " Tensor slot_mapping," |
| " str kv_cache_dtype," |
| " Tensor k_scale, Tensor v_scale) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); |
| #elif defined(METAL_KERNEL) |
| ops.impl("reshape_and_cache", torch::kMPS, reshape_and_cache); |
| #endif |
|
|
| |
| ops.def( |
| "reshape_and_cache_flash(Tensor key, Tensor value," |
| " Tensor! key_cache," |
| " Tensor! value_cache," |
| " Tensor slot_mapping," |
| " str kv_cache_dtype," |
| " Tensor k_scale, Tensor v_scale) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); |
| #elif defined(METAL_KERNEL) |
| ops.impl("reshape_and_cache_flash", torch::kMPS, reshape_and_cache_flash); |
| #endif |
|
|
| |
| ops.def("get_device_attribute(int attribute, int device_id) -> int"); |
| ops.impl("get_device_attribute", &get_device_attribute); |
|
|
| |
| ops.def( |
| "get_max_shared_memory_per_block_device_attribute(int device_id) -> int"); |
| ops.impl("get_max_shared_memory_per_block_device_attribute", |
| &get_max_shared_memory_per_block_device_attribute); |
|
|
| |
| ops.def( |
| "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " |
| "str kv_cache_dtype) -> ()"); |
| #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL) |
| ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); |
| #elif defined(METAL_KERNEL) |
| ops.impl("convert_fp8", torch::kMPS, convert_fp8); |
| #endif |
| } |
|
|
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
|
|