| #pragma once |
|
|
| #include <torch/torch.h> |
| #include <optional> |
| #include <tuple> |
| #include <vector> |
|
|
| |
| std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> |
| sparse_decode_fwd( |
| const at::Tensor &q, |
| const at::Tensor &kv, |
| const at::Tensor &indices, |
| const std::optional<at::Tensor> &topk_length, |
| const std::optional<at::Tensor> &attn_sink, |
| const std::optional<at::Tensor> &tile_scheduler_metadata, |
| const std::optional<at::Tensor> &num_splits, |
| const std::optional<at::Tensor> &extra_kv, |
| const std::optional<at::Tensor> &extra_indices, |
| const std::optional<at::Tensor> &extra_topk_length, |
| int64_t d_v, |
| double sm_scale |
| ); |
|
|
| |
| std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> |
| dense_decode_fwd( |
| at::Tensor q, |
| const at::Tensor &kcache, |
| int64_t head_size_v, |
| const at::Tensor &seqlens_k, |
| const at::Tensor &block_table, |
| double softmax_scale, |
| bool is_causal, |
| const std::optional<at::Tensor> &tile_scheduler_metadata, |
| const std::optional<at::Tensor> &num_splits |
| ); |
|
|
| |
| std::vector<at::Tensor> sparse_prefill_fwd( |
| const at::Tensor &q, |
| const at::Tensor &kv, |
| const at::Tensor &indices, |
| double sm_scale, |
| int64_t d_v, |
| const std::optional<at::Tensor> &attn_sink, |
| const std::optional<at::Tensor> &topk_length |
| ); |
|
|
| |
| void dense_prefill_fwd( |
| at::Tensor workspace_buffer, |
| at::Tensor q, |
| at::Tensor k, |
| at::Tensor v, |
| at::Tensor cumulative_seqlen_q, |
| at::Tensor cumulative_seqlen_kv, |
| at::Tensor o, |
| at::Tensor lse, |
| int64_t mask_mode_code, |
| double softmax_scale, |
| int64_t max_seqlen_q, |
| int64_t max_seqlen_kv, |
| bool is_varlen |
| ); |
|
|
| |
| void dense_prefill_bwd( |
| at::Tensor workspace_buffer, |
| at::Tensor d_o, |
| at::Tensor q, |
| at::Tensor k, |
| at::Tensor v, |
| at::Tensor o, |
| at::Tensor lse, |
| at::Tensor cumulative_seqlen_q, |
| at::Tensor cumulative_seqlen_kv, |
| at::Tensor dq, |
| at::Tensor dk, |
| at::Tensor dv, |
| int64_t mask_mode_code, |
| double softmax_scale, |
| int64_t max_seqlen_q, |
| int64_t max_seqlen_kv, |
| bool is_varlen |
| ); |
|
|
| |
|
|
| std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> |
| sparse_attn_decode_interface( |
| const at::Tensor &q, |
| const at::Tensor &kv, |
| const at::Tensor &indices, |
| const std::optional<at::Tensor> &topk_length, |
| const std::optional<at::Tensor> &attn_sink, |
| std::optional<at::Tensor> &tile_scheduler_metadata, |
| std::optional<at::Tensor> &num_splits, |
| const std::optional<at::Tensor> &extra_kv, |
| const std::optional<at::Tensor> &extra_indices, |
| const std::optional<at::Tensor> &extra_topk_length, |
| int d_v, |
| float sm_scale |
| ); |
|
|
| std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>> |
| dense_attn_decode_interface( |
| at::Tensor &q, |
| const at::Tensor &kcache, |
| const int head_size_v, |
| const at::Tensor &seqlens_k, |
| const at::Tensor &block_table, |
| const float softmax_scale, |
| bool is_causal, |
| std::optional<at::Tensor> &tile_scheduler_metadata, |
| std::optional<at::Tensor> &num_splits |
| ); |
|
|
| std::vector<at::Tensor> sparse_attn_prefill_interface( |
| const at::Tensor &q, |
| const at::Tensor &kv, |
| const at::Tensor &indices, |
| float sm_scale, |
| int d_v, |
| const std::optional<at::Tensor> &attn_sink, |
| const std::optional<at::Tensor> &topk_length |
| ); |
|
|
| |
| |
| static void FMHACutlassSM100FwdRun( |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| int , |
| float , |
| int , |
| int , |
| bool |
| ) { |
| TORCH_CHECK(false, "dense_prefill_fwd requires SM100 (Blackwell) GPU and CUDA 12.9+. This build only supports SM90 (Hopper)."); |
| } |
|
|
| static void FMHACutlassSM100BwdRun( |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| at::Tensor , |
| int , |
| float , |
| int , |
| int , |
| bool |
| ) { |
| TORCH_CHECK(false, "dense_prefill_bwd requires SM100 (Blackwell) GPU and CUDA 12.9+. This build only supports SM90 (Hopper)."); |
| } |
|
|