Kernels
File size: 1,006 Bytes
8cf888e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"

/*
 * Register flash_attention_forward as a Torch operator.
 *
 * The function schema follows ATen conventions:
 *   - Tensor! out  : mutated output (pre-allocated by caller)
 *   - Tensor q/k/v : inputs
 *   - bool causal  : whether to use a causal mask
 *
 * The Triton kernel is launched from the Python side via the __init__.py
 * wrapper, so the C++ body here simply validates shapes and delegates.
 *
 * NOTE: For a pure-Triton kernel you may not need a .cu file at all; the
 * C++ binding exists only so that the op is reachable via torch.ops.<name>.
 */
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    // Operator schema: `out` carries the `!` annotation because it is written
    // in place; q, k and v are plain tensor inputs, `causal` is a boolean
    // flag, and the op itself returns nothing.
    static constexpr const char* kForwardSchema =
        "flash_attention_forward(  Tensor! out,  Tensor q,  Tensor k,"
        "  Tensor v,  bool causal) -> ()";

    // Declare the operator in this extension's library ...
    ops.def(kForwardSchema);

    // ... then attach the C++ entry point as its implementation for the CUDA
    // dispatch key.
    ops.impl("flash_attention_forward", torch::kCUDA, &flash_attention_forward);
}

// NOTE(review): REGISTER_EXTENSION is not defined in this file; presumably it
// comes from "registration.h" and emits the Python module entry point for
// TORCH_EXTENSION_NAME so the shared library can be imported — confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)