#include #include #include "registration.h" #include "torch_binding.h" // Wrapper functions that adapt the interface for TORCH_LIBRARY std::tuple, std::optional> sparse_decode_fwd( const at::Tensor &q, const at::Tensor &kv, const at::Tensor &indices, const std::optional &topk_length, const std::optional &attn_sink, const std::optional &tile_scheduler_metadata, const std::optional &num_splits, const std::optional &extra_kv, const std::optional &extra_indices, const std::optional &extra_topk_length, int64_t d_v, double sm_scale ) { // Create mutable copies for the interface that modifies these std::optional tile_scheduler_metadata_mut = tile_scheduler_metadata; std::optional num_splits_mut = num_splits; return sparse_attn_decode_interface( q, kv, indices, topk_length, attn_sink, tile_scheduler_metadata_mut, num_splits_mut, extra_kv, extra_indices, extra_topk_length, static_cast(d_v), static_cast(sm_scale) ); } std::tuple, std::optional> dense_decode_fwd( at::Tensor q, const at::Tensor &kcache, int64_t head_size_v, const at::Tensor &seqlens_k, const at::Tensor &block_table, double softmax_scale, bool is_causal, const std::optional &tile_scheduler_metadata, const std::optional &num_splits ) { // Create mutable copies for the interface that modifies these std::optional tile_scheduler_metadata_mut = tile_scheduler_metadata; std::optional num_splits_mut = num_splits; return dense_attn_decode_interface( q, kcache, static_cast(head_size_v), seqlens_k, block_table, static_cast(softmax_scale), is_causal, tile_scheduler_metadata_mut, num_splits_mut ); } std::vector sparse_prefill_fwd( const at::Tensor &q, const at::Tensor &kv, const at::Tensor &indices, double sm_scale, int64_t d_v, const std::optional &attn_sink, const std::optional &topk_length ) { return sparse_attn_prefill_interface( q, kv, indices, static_cast(sm_scale), static_cast(d_v), attn_sink, topk_length ); } void dense_prefill_fwd( at::Tensor workspace_buffer, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor o, at::Tensor lse, int64_t mask_mode_code, double softmax_scale, int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_varlen ) { FMHACutlassSM100FwdRun( workspace_buffer, q, k, v, cumulative_seqlen_q, cumulative_seqlen_kv, o, lse, static_cast(mask_mode_code), static_cast(softmax_scale), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), is_varlen ); } void dense_prefill_bwd( at::Tensor workspace_buffer, at::Tensor d_o, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor o, at::Tensor lse, at::Tensor cumulative_seqlen_q, at::Tensor cumulative_seqlen_kv, at::Tensor dq, at::Tensor dk, at::Tensor dv, int64_t mask_mode_code, double softmax_scale, int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_varlen ) { FMHACutlassSM100BwdRun( workspace_buffer, d_o, q, k, v, o, lse, cumulative_seqlen_q, cumulative_seqlen_kv, dq, dk, dv, static_cast(mask_mode_code), static_cast(softmax_scale), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), is_varlen ); } TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Sparse decode forward ops.def( "sparse_decode_fwd(Tensor q, Tensor kv, Tensor indices, " "Tensor? topk_length, Tensor? attn_sink, " "Tensor? tile_scheduler_metadata, Tensor? num_splits, " "Tensor? extra_kv, Tensor? extra_indices, Tensor? extra_topk_length, " "int d_v, float sm_scale) -> (Tensor, Tensor, Tensor?, Tensor?)" ); ops.impl("sparse_decode_fwd", torch::kCUDA, &sparse_decode_fwd); // Dense decode forward ops.def( "dense_decode_fwd(Tensor q, Tensor kcache, int head_size_v, " "Tensor seqlens_k, Tensor block_table, " "float softmax_scale, bool is_causal, " "Tensor? tile_scheduler_metadata, Tensor? num_splits) -> (Tensor, Tensor, Tensor?, Tensor?)" ); ops.impl("dense_decode_fwd", torch::kCUDA, &dense_decode_fwd); // Sparse prefill forward ops.def( "sparse_prefill_fwd(Tensor q, Tensor kv, Tensor indices, " "float sm_scale, int d_v, " "Tensor? attn_sink, Tensor? topk_length) -> Tensor[]" ); ops.impl("sparse_prefill_fwd", torch::kCUDA, &sparse_prefill_fwd); // Dense prefill forward (SM100) ops.def( "dense_prefill_fwd(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, " "Tensor cumulative_seqlen_q, Tensor cumulative_seqlen_kv, " "Tensor o, Tensor lse, " "int mask_mode_code, float softmax_scale, " "int max_seqlen_q, int max_seqlen_kv, bool is_varlen) -> ()" ); ops.impl("dense_prefill_fwd", torch::kCUDA, &dense_prefill_fwd); // Dense prefill backward (SM100) ops.def( "dense_prefill_bwd(Tensor workspace_buffer, Tensor d_o, " "Tensor q, Tensor k, Tensor v, Tensor o, Tensor lse, " "Tensor cumulative_seqlen_q, Tensor cumulative_seqlen_kv, " "Tensor dq, Tensor dk, Tensor dv, " "int mask_mode_code, float softmax_scale, " "int max_seqlen_q, int max_seqlen_kv, bool is_varlen) -> ()" ); ops.impl("dense_prefill_bwd", torch::kCUDA, &dense_prefill_bwd); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)