File size: 1,334 Bytes
0940df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <torch/extension.h>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <ATen/core/op_registration/op_registration.h>

// CUDA declarations

void mf_fused_forward_cuda(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int H,
    const int W,
    const std::vector<int64_t>& win_r,
    const int attn_num,
    const std::string& attn_type,
    const float scale);

void mf_fused_forward(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    const int64_t H,
    const int64_t W,
    const std::vector<int64_t>& win_r,
    const int64_t attn_num,
    const std::string& attn_type,
    const double scale)
{
    mf_fused_forward_cuda(max_offset, q, k, v, output, attn_out, H, W, win_r, attn_num, attn_type, static_cast<float>(scale));
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("fused_forward", &mf_fused_forward, "Fused forward pass (CUDA)");
}

TORCH_LIBRARY(match_attention, m)
{
    m.def("fused_forward(Tensor max_offset, Tensor q, Tensor k, Tensor v, Tensor(a!) output, Tensor(b!) attn_out, int H, int W, int[] win_r, int attn_num, str attn_type, float scale) -> ()", &mf_fused_forward);
}