| #pragma once |
|
|
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include <torch/torch.h>
|
|
/// Single-pass ("v1") paged-attention kernel entry point: attends `query`
/// against the block-organized `key_cache`/`value_cache` and writes the
/// result into `out` in place.
/// NOTE(review): this header carries signatures only — the semantics below
/// are inferred from parameter names; confirm against the kernel definition.
/// @param out            [out] attention output tensor (mutated in place).
/// @param query          query tensor.
/// @param key_cache      paged key cache (layout presumably governed by `block_size`).
/// @param value_cache    paged value cache.
/// @param num_kv_heads   number of key/value heads — may differ from query heads (GQA/MQA); confirm.
/// @param scale          softmax scaling factor — presumably 1/sqrt(head_dim); confirm.
/// @param block_tables   per-sequence logical→physical cache-block mapping — presumed.
/// @param seq_lens       per-sequence token counts — presumed.
/// @param block_size     tokens per cache block.
/// @param max_seq_len    upper bound on sequence length — presumably sizes internal buffers.
/// @param alibi_slopes   optional ALiBi bias slopes; std::nullopt disables — presumed.
/// @param kv_cache_dtype string tag selecting the cache element type; accepted values defined by the kernel — confirm.
/// @param k_scale        per-cache scale tensor, presumably for (de)quantized KV caches — confirm.
/// @param v_scale        as `k_scale`, for the value cache.
/// @param tp_rank        tensor-parallel rank — presumed.
/// @param blocksparse_local_blocks,blocksparse_vert_stride,blocksparse_block_size,blocksparse_head_sliding_step
///                       block-sparse attention configuration — semantics defined by the kernel; confirm.
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
|
|
/// "v2" paged-attention entry point. Same contract as paged_attention_v1 but
/// takes three extra workspace tensors (`exp_sums`, `max_logits`, `tmp_out`),
/// which suggests a partitioned/split-sequence reduction variant —
/// NOTE(review): inferred from names; confirm against the kernel definition.
/// @param out        [out] final attention output tensor (mutated in place).
/// @param exp_sums   [out?] per-partition softmax denominator workspace — presumed.
/// @param max_logits [out?] per-partition running-max workspace — presumed.
/// @param tmp_out    [out?] per-partition partial output, reduced into `out` — presumed.
/// Remaining parameters match paged_attention_v1: `query` is attended against
/// the paged `key_cache`/`value_cache`, addressed via `block_tables` and
/// `seq_lens` with `block_size`-token blocks; `alibi_slopes` is an optional
/// bias; `kv_cache_dtype`/`k_scale`/`v_scale` select and rescale the cache
/// element type; `tp_rank` and the `blocksparse_*` values configure
/// tensor-parallel and block-sparse execution (semantics defined by the
/// kernel — confirm).
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
|
|
/// Moves cache blocks between `src` and `dst` according to `block_mapping`.
/// NOTE(review): presumably used to swap KV-cache blocks between devices
/// (e.g. GPU↔CPU); the mapping's layout (src-index → dst-index pairs) is
/// defined by the implementation — confirm.
/// @param src           source cache tensor (mutated? — declared non-const; confirm).
/// @param dst           destination cache tensor (written).
/// @param block_mapping block index pairs describing which blocks to move.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);
|
|
| |
| |
| |
/// Copies cache blocks within each of the given key/value cache tensors
/// according to `block_mapping`.
/// NOTE(review): presumably one cache tensor per layer, with the same
/// src→dst block mapping applied to all of them — confirm against the
/// implementation.
/// @param key_caches    per-layer key cache tensors (contents mutated).
/// @param value_caches  per-layer value cache tensors (contents mutated).
/// @param block_mapping block index pairs describing the copies to perform.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);
|
|
/// Writes new `key`/`value` token entries into the paged caches at the
/// positions given by `slot_mapping`.
/// NOTE(review): inferred from names — the exact cache layout produced by
/// the "reshape" is defined by the kernel; confirm.
/// @param key            new key tensor for the current tokens.
/// @param value          new value tensor for the current tokens.
/// @param key_cache      [out] paged key cache (mutated in place).
/// @param value_cache    [out] paged value cache (mutated in place).
/// @param slot_mapping   per-token destination slot indices — presumed.
/// @param kv_cache_dtype string tag selecting the cache element type — confirm accepted values.
/// @param k_scale        scale tensor, presumably for quantized caches — confirm.
/// @param v_scale        as `k_scale`, for the value cache.
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);
|
|
/// Variant of reshape_and_cache with an identical signature; the "_flash"
/// suffix suggests it targets a FlashAttention-style cache layout —
/// NOTE(review): inferred from the name only; confirm against the kernel.
/// @param key            new key tensor for the current tokens.
/// @param value          new value tensor for the current tokens.
/// @param key_cache      [out] paged key cache (mutated in place).
/// @param value_cache    [out] paged value cache (mutated in place).
/// @param slot_mapping   per-token destination slot indices — presumed.
/// @param kv_cache_dtype string tag selecting the cache element type — confirm accepted values.
/// @param k_scale        scale tensor, presumably for quantized caches — confirm.
/// @param v_scale        as `k_scale`, for the value cache.
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype,
                             torch::Tensor& k_scale, torch::Tensor& v_scale);
|
|
/// Queries a device attribute for the given device.
/// NOTE(review): presumably a thin wrapper over the runtime's device-attribute
/// query (e.g. cudaDeviceGetAttribute), with `attribute` being the numeric
/// attribute enum value — confirm against the implementation.
/// @param attribute numeric attribute identifier.
/// @param device_id target device index.
/// @return the attribute's value.
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
|
/// Convenience query: maximum shared memory per block for `device_id`.
/// NOTE(review): presumably forwards to get_device_attribute with the
/// corresponding attribute id — confirm.
/// @param device_id target device index.
/// @return maximum shared memory per block, presumably in bytes — confirm.
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
|
/// Converts a cache tensor to/from FP8, writing the result into `dst_cache`.
/// NOTE(review): conversion direction and FP8 format appear to be selected by
/// `kv_cache_dtype` and the tensors' dtypes — inferred from names; confirm.
/// @param dst_cache      [out] destination cache tensor (mutated in place).
/// @param src_cache      source cache tensor.
/// @param scale          scaling factor applied during conversion — presumed.
/// @param kv_cache_dtype string tag naming the target cache dtype — confirm accepted values.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);
|
|