#include #include "registration.h" #include "torch_binding.h" /* * Register flash_attention_forward as a Torch operator. * * The function schema follows ATen conventions: * - Tensor! out : mutated output (pre-allocated by caller) * - Tensor q/k/v : inputs * - bool causal : whether to use a causal mask * * The Triton kernel is launched from the Python side via the __init__.py * wrapper, so the C++ body here simply validates shapes and delegates. * * NOTE: For a pure-Triton kernel you may not need a .cu file at all; the * C++ binding exists only so that the op is reachable via torch.ops.. */ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "flash_attention_forward(" " Tensor! out," " Tensor q," " Tensor k," " Tensor v," " bool causal" ") -> ()" ); ops.impl("flash_attention_forward", torch::kCUDA, &flash_attention_forward); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)