| | #pragma once |
| |
|
| | #include <torch/torch.h> |
| |
|
| | |
| | std::vector<torch::Tensor> |
| | mha_fwd( |
| | torch::Tensor &q, |
| | const torch::Tensor &k, |
| | const torch::Tensor &v, |
| | c10::optional<torch::Tensor> out_,\ |
| | c10::optional<torch::Tensor> alibi_slopes_, |
| | const double p_dropout, |
| | const double softmax_scale, |
| | bool is_causal, |
| | const int64_t window_size_left, |
| | const int64_t window_size_right, |
| | const double softcap, |
| | const bool return_softmax, |
| | c10::optional<at::Generator> gen_); |
| |
|
| | std::vector<torch::Tensor> |
| | mha_varlen_fwd( |
| | at::Tensor &q, |
| | const torch::Tensor &k, |
| | const torch::Tensor &v, |
| | c10::optional<torch::Tensor> out_, |
| | const torch::Tensor &cu_seqlens_q, |
| | const torch::Tensor &cu_seqlens_k, |
| | c10::optional<torch::Tensor> seqused_k, |
| | |
| | c10::optional<torch::Tensor> leftpad_k_, |
| | c10::optional<torch::Tensor> block_table_, |
| | c10::optional<torch::Tensor> alibi_slopes_, |
| | int64_t max_seqlen_q, |
| | const int64_t max_seqlen_k, |
| | const double p_dropout, |
| | const double softmax_scale, |
| | const bool zero_tensors, |
| | bool is_causal, |
| | int64_t window_size_left, |
| | int64_t window_size_right, |
| | const double softcap, |
| | const bool return_softmax, |
| | std::optional<at::Generator> gen_); |
| |
|
| | std::vector<torch::Tensor> |
| | mha_bwd(const torch::Tensor &dout, |
| | const torch::Tensor &q, |
| | const torch::Tensor &k, |
| | const torch::Tensor &v, |
| | const torch::Tensor &out, |
| | const torch::Tensor &softmax_lse, |
| | const c10::optional<torch::Tensor> &dq_, |
| | const c10::optional<torch::Tensor> &dk_, |
| | const c10::optional<torch::Tensor> &dv_, |
| | const c10::optional<torch::Tensor> &alibi_slopes_, |
| | const double p_dropout, |
| | const double softmax_scale, |
| | const bool is_causal, |
| | const int64_t window_size_left, |
| | const int64_t window_size_right, |
| | const double softcap, |
| | const bool deterministic, |
| | c10::optional<at::Generator> gen_, |
| | const c10::optional<torch::Tensor> &rng_state); |
| |
|
| |
|
| | std::vector<torch::Tensor> |
| | mha_varlen_bwd( |
| | const torch::Tensor &dout, |
| | const torch::Tensor &q, |
| | const torch::Tensor &k, |
| | const torch::Tensor &v, |
| | const torch::Tensor &out, |
| | const torch::Tensor &softmax_lse, |
| | const c10::optional<torch::Tensor> &dq_, |
| | const c10::optional<torch::Tensor> &dk_, |
| | const c10::optional<torch::Tensor> &dv_, |
| | const torch::Tensor &cu_seqlens_q, |
| | const torch::Tensor &cu_seqlens_k, |
| | const c10::optional<torch::Tensor> &alibi_slopes_, |
| | const int64_t max_seqlen_q, |
| | const int64_t max_seqlen_k, |
| | const double p_dropout, |
| | const double softmax_scale, |
| | const bool zero_tensors, |
| | const bool is_causal, |
| | const int64_t window_size_left, |
| | const int64_t window_size_right, |
| | const double softcap, |
| | const bool deterministic, |
| | c10::optional<at::Generator> gen_, |
| | const c10::optional<torch::Tensor> &rng_state); |
| |
|
| | std::vector<torch::Tensor> |
| | mha_fwd_kvcache( |
| | const torch::Tensor &q, |
| | const torch::Tensor &kcache, |
| | const torch::Tensor &vcache, |
| | const c10::optional<torch::Tensor> &k_, |
| | const c10::optional<torch::Tensor> &v_, |
| | const c10::optional<torch::Tensor> &seqlens_k_, |
| | const c10::optional<torch::Tensor> &rotary_cos_, |
| | const c10::optional<torch::Tensor> &rotary_sin_, |
| | const c10::optional<torch::Tensor> &cache_batch_idx_, |
| | const c10::optional<torch::Tensor> &leftpad_k_, |
| | const c10::optional<torch::Tensor> &block_table_, |
| | const c10::optional<torch::Tensor> &alibi_slopes_, |
| | const c10::optional<torch::Tensor> &out_, |
| | const double softmax_scale, |
| | bool is_causal, |
| | const int64_t window_size_left, |
| | const int64_t window_size_right, |
| | const double softcap, |
| | bool is_rotary_interleaved, |
| | const int64_t num_splits); |
| |
|