#include <torch/library.h>
#include <ATen/Tensor.h>
#include "registration.h"
#include "torch_binding.h"
// Wrapper functions that adapt the interface for TORCH_LIBRARY
// TORCH_LIBRARY-facing adapter for the sparse-attention decode kernel.
// The schema exposes `int`/`float` scalars, which TorchScript maps to
// int64_t/double; this wrapper narrows them to the int/float the
// underlying interface expects.
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
sparse_decode_fwd(
    const at::Tensor &q,
    const at::Tensor &kv,
    const at::Tensor &indices,
    const std::optional<at::Tensor> &topk_length,
    const std::optional<at::Tensor> &attn_sink,
    const std::optional<at::Tensor> &tile_scheduler_metadata,
    const std::optional<at::Tensor> &num_splits,
    const std::optional<at::Tensor> &extra_kv,
    const std::optional<at::Tensor> &extra_indices,
    const std::optional<at::Tensor> &extra_topk_length,
    int64_t d_v,
    double sm_scale
) {
    // The interface takes these two by non-const reference, so it needs
    // mutable storage; hand it local copies of the incoming optionals.
    auto sched_meta_copy = tile_scheduler_metadata;
    auto num_splits_copy = num_splits;
    return sparse_attn_decode_interface(
        q, kv, indices, topk_length, attn_sink,
        sched_meta_copy, num_splits_copy,
        extra_kv, extra_indices, extra_topk_length,
        static_cast<int>(d_v), static_cast<float>(sm_scale));
}
// TORCH_LIBRARY-facing adapter for the dense (paged-KV) decode kernel.
// Narrows the schema's int64_t/double scalars to the int/float the
// underlying interface expects.
std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>>
dense_decode_fwd(
    at::Tensor q,
    const at::Tensor &kcache,
    int64_t head_size_v,
    const at::Tensor &seqlens_k,
    const at::Tensor &block_table,
    double softmax_scale,
    bool is_causal,
    const std::optional<at::Tensor> &tile_scheduler_metadata,
    const std::optional<at::Tensor> &num_splits
) {
    // The interface takes these two by non-const reference, so it needs
    // mutable storage; hand it local copies of the incoming optionals.
    auto sched_meta_copy = tile_scheduler_metadata;
    auto num_splits_copy = num_splits;
    return dense_attn_decode_interface(
        q, kcache, static_cast<int>(head_size_v),
        seqlens_k, block_table,
        static_cast<float>(softmax_scale), is_causal,
        sched_meta_copy, num_splits_copy);
}
// TORCH_LIBRARY-facing adapter for the sparse-attention prefill kernel:
// a pure forwarder that only narrows the scalar argument types.
std::vector<at::Tensor> sparse_prefill_fwd(
    const at::Tensor &q,
    const at::Tensor &kv,
    const at::Tensor &indices,
    double sm_scale,
    int64_t d_v,
    const std::optional<at::Tensor> &attn_sink,
    const std::optional<at::Tensor> &topk_length
) {
    const auto scale = static_cast<float>(sm_scale);
    const auto head_dim_v = static_cast<int>(d_v);
    return sparse_attn_prefill_interface(
        q, kv, indices, scale, head_dim_v, attn_sink, topk_length);
}
// TORCH_LIBRARY-facing adapter for the SM100 (Cutlass) dense prefill
// forward kernel. Results are written in place into `o` and `lse`;
// scalar arguments are narrowed from the schema's int64_t/double.
void dense_prefill_fwd(
    at::Tensor workspace_buffer,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor cumulative_seqlen_q,
    at::Tensor cumulative_seqlen_kv,
    at::Tensor o,
    at::Tensor lse,
    int64_t mask_mode_code,
    double softmax_scale,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    bool is_varlen
) {
    const auto mask_mode = static_cast<int>(mask_mode_code);
    const auto scale = static_cast<float>(softmax_scale);
    const auto max_q = static_cast<int>(max_seqlen_q);
    const auto max_kv = static_cast<int>(max_seqlen_kv);
    FMHACutlassSM100FwdRun(
        workspace_buffer, q, k, v,
        cumulative_seqlen_q, cumulative_seqlen_kv,
        o, lse, mask_mode, scale, max_q, max_kv, is_varlen);
}
// TORCH_LIBRARY-facing adapter for the SM100 (Cutlass) dense prefill
// backward kernel. Gradients are written in place into `dq`/`dk`/`dv`;
// scalar arguments are narrowed from the schema's int64_t/double.
void dense_prefill_bwd(
    at::Tensor workspace_buffer,
    at::Tensor d_o,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor o,
    at::Tensor lse,
    at::Tensor cumulative_seqlen_q,
    at::Tensor cumulative_seqlen_kv,
    at::Tensor dq,
    at::Tensor dk,
    at::Tensor dv,
    int64_t mask_mode_code,
    double softmax_scale,
    int64_t max_seqlen_q,
    int64_t max_seqlen_kv,
    bool is_varlen
) {
    const auto mask_mode = static_cast<int>(mask_mode_code);
    const auto scale = static_cast<float>(softmax_scale);
    const auto max_q = static_cast<int>(max_seqlen_q);
    const auto max_kv = static_cast<int>(max_seqlen_kv);
    FMHACutlassSM100BwdRun(
        workspace_buffer, d_o, q, k, v, o, lse,
        cumulative_seqlen_q, cumulative_seqlen_kv,
        dq, dk, dv, mask_mode, scale, max_q, max_kv, is_varlen);
}
// Operator registration. Each `def` declares the TorchScript schema
// (names, Tensor? = optional tensor, -> return signature) and each
// `impl` binds the CUDA-dispatch-key implementation to the wrapper
// above. Schema argument names/types must stay in sync with the C++
// wrapper signatures.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Sparse decode forward
ops.def(
"sparse_decode_fwd(Tensor q, Tensor kv, Tensor indices, "
"Tensor? topk_length, Tensor? attn_sink, "
"Tensor? tile_scheduler_metadata, Tensor? num_splits, "
"Tensor? extra_kv, Tensor? extra_indices, Tensor? extra_topk_length, "
"int d_v, float sm_scale) -> (Tensor, Tensor, Tensor?, Tensor?)"
);
ops.impl("sparse_decode_fwd", torch::kCUDA, &sparse_decode_fwd);
// Dense decode forward
ops.def(
"dense_decode_fwd(Tensor q, Tensor kcache, int head_size_v, "
"Tensor seqlens_k, Tensor block_table, "
"float softmax_scale, bool is_causal, "
"Tensor? tile_scheduler_metadata, Tensor? num_splits) -> (Tensor, Tensor, Tensor?, Tensor?)"
);
ops.impl("dense_decode_fwd", torch::kCUDA, &dense_decode_fwd);
// Sparse prefill forward
ops.def(
"sparse_prefill_fwd(Tensor q, Tensor kv, Tensor indices, "
"float sm_scale, int d_v, "
"Tensor? attn_sink, Tensor? topk_length) -> Tensor[]"
);
ops.impl("sparse_prefill_fwd", torch::kCUDA, &sparse_prefill_fwd);
// Dense prefill forward (SM100)
// Returns () because outputs are written in place into o/lse.
ops.def(
"dense_prefill_fwd(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, "
"Tensor cumulative_seqlen_q, Tensor cumulative_seqlen_kv, "
"Tensor o, Tensor lse, "
"int mask_mode_code, float softmax_scale, "
"int max_seqlen_q, int max_seqlen_kv, bool is_varlen) -> ()"
);
ops.impl("dense_prefill_fwd", torch::kCUDA, &dense_prefill_fwd);
// Dense prefill backward (SM100)
// Returns () because gradients are written in place into dq/dk/dv.
ops.def(
"dense_prefill_bwd(Tensor workspace_buffer, Tensor d_o, "
"Tensor q, Tensor k, Tensor v, Tensor o, Tensor lse, "
"Tensor cumulative_seqlen_q, Tensor cumulative_seqlen_kv, "
"Tensor dq, Tensor dk, Tensor dv, "
"int mask_mode_code, float softmax_scale, "
"int max_seqlen_q, int max_seqlen_kv, bool is_varlen) -> ()"
);
ops.impl("dense_prefill_bwd", torch::kCUDA, &dense_prefill_bwd);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)