// Kernels: flash-attention-1-triton / torch-ext / torch_binding.cpp
// Commit 8cf888e (sigmoid-neuron):
//   feat: implement Flash Attention 1 using Triton with PyTorch C++ bindings and test suite
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
/*
* Register flash_attention_forward as a Torch operator.
*
* The function schema follows ATen conventions:
* - Tensor! out : mutated output (pre-allocated by caller)
* - Tensor q/k/v : inputs
* - bool causal : whether to use a causal mask
*
* The Triton kernel is launched from the Python side via the __init__.py
* wrapper, so the C++ body here simply validates shapes and delegates.
*
* NOTE: For a pure-Triton kernel you may not need a .cu file at all; the
* C++ binding exists only so that the op is reachable via torch.ops.<name>.
*/
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Declare the operator schema. `out` carries the `Tensor!` (mutable)
  // annotation, q/k/v are plain Tensor arguments, `causal` is a bool, and
  // the operator returns nothing (`-> ()`).
  ops.def(
      "flash_attention_forward( Tensor! out, Tensor q, Tensor k, Tensor v,"
      " bool causal) -> ()");

  // Attach flash_attention_forward as the implementation used for the CUDA
  // dispatch key. This must follow the def() above so the schema exists.
  ops.impl("flash_attention_forward", torch::kCUDA, &flash_attention_forward);
}
// Invoke REGISTER_EXTENSION with this extension's name.
// NOTE(review): the macro is not defined in this file; presumably it comes
// from registration.h and emits the Python module entry point for the
// extension — confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)