// TODO: Register the remaining entry points with the torch library below
// (varlen_fwd, bwd, varlen_bwd, fwd_kvcache) via matching ops.def/ops.impl
// pairs. For reference, the retired pybind11 bindings exposed:
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.doc() = "FlashAttention";
//     m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
//     m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
//     m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
//     m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
//     m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
// }
// Register the FlashAttention forward op with the PyTorch dispatcher.
// TORCH_LIBRARY_EXPAND is presumably a project macro that forces expansion of
// TORCH_EXTENSION_NAME before invoking TORCH_LIBRARY -- TODO confirm against
// the macro's definition (not visible in this chunk).
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Dispatcher schema for mha_fwd. In schema syntax, "Tensor!" marks an
  // argument the op may mutate in place, "Tensor?" an optional (nullable)
  // argument, and "-> Tensor[]" a list-of-tensors return.
  // NOTE(review): only mha_fwd is registered here; the other entry points
  // still need ops.def/ops.impl pairs.
  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
  // Bind the CUDA kernel implementation for the schema declared above.
  // &mha_fwd is declared elsewhere in the project -- its signature must match
  // the schema string exactly or registration fails at load time.
  ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
}
// Presumably emits the Python module init entry point (PyInit_<name>) for this
// extension -- TODO confirm against the REGISTER_EXTENSION macro definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)