File size: 512 Bytes
0940df6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#ifndef MATCH_FORMER_FUSED_FORWARD_HPP_
#define MATCH_FORMER_FUSED_FORWARD_HPP_

// Guard name deliberately avoids a leading underscore followed by an uppercase
// letter: such identifiers are reserved for the implementation.

#include <cstdint>  // std::int64_t
#include <string>
#include <vector>

// Forward declaration so this header is self-contained. A function
// *declaration* may name an incomplete parameter type, so the full ATen
// headers are not needed here. Translation units that call or define
// match_former_fused_forward must still include the full tensor definition
// (e.g. <torch/extension.h> — TODO confirm which header the project uses).
namespace at {
class Tensor;
}  // namespace at

/// @brief Fused forward function that combines all match former operations
///        in a single call. The definition is not part of this header.
///
/// The function returns void, so any results must be delivered through its
/// tensor arguments. NOTE(review): `output` and `attn_out` look like
/// caller-provided destination tensors that are written in place — confirm
/// against the definition.
///
/// Top-level `const` on by-value parameters is not part of the function type,
/// so it is omitted here; the definition may still mark them `const`.
///
/// @param max_offset Tensor argument; its role is not visible here — TODO document.
/// @param q          Presumably the query tensor — TODO confirm shape/dtype.
/// @param k          Presumably the key tensor — TODO confirm shape/dtype.
/// @param v          Presumably the value tensor — TODO confirm shape/dtype.
/// @param output     Presumably the destination for the fused result — TODO confirm.
/// @param attn_out   Presumably the destination for attention results — TODO confirm.
/// @param H          Integer extent; presumably spatial height — TODO confirm.
/// @param W          Integer extent; presumably spatial width — TODO confirm.
/// @param win_r      List of 64-bit integers; presumably window radii — TODO
///                   confirm expected length and meaning.
/// @param attn_num   Integer count related to attention; meaning to be confirmed.
/// @param attn_type  String selecting an attention variant; accepted values are
///                   not visible here.
/// @param scale      Floating-point scale factor (presumably applied to attention
///                   logits — TODO confirm).
void match_former_fused_forward(
    at::Tensor max_offset,
    at::Tensor q,
    at::Tensor k,
    at::Tensor v,
    at::Tensor output,
    at::Tensor attn_out,
    int H,
    int W,
    const std::vector<std::int64_t>& win_r,
    int attn_num,
    const std::string& attn_type,
    float scale);

#endif  // MATCH_FORMER_FUSED_FORWARD_HPP_