File size: 740 Bytes
84ec9f0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"


// Operator registration for this extension. Every operator is first declared
// with its schema string (`ops.def`) and then bound to a C++ entry point under
// the CUDA dispatch key (`ops.impl`).
// NOTE(review): TORCH_LIBRARY_EXPAND is not defined in this file — presumably
// it comes from "registration.h", and the kernel entry points from
// "torch_binding.h"; confirm against those headers.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    // Backward operator: seven tensors plus three integer sizes in, a tuple
    // of three tensors out (per the schema below).
    // NOTE(review): the "dkvq" suffix suggests the outputs are gradients
    // w.r.t. k, v and q — verify against the kernel implementation.
    ops.def(
        "s2_attention_bwd_dkvq_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)");
    ops.impl(
        "s2_attention_bwd_dkvq_cuda",
        torch::kCUDA,
        &s2_attention_bwd_dkvq_cuda);

    // Forward operator: same arguments as the backward schema minus `dy`,
    // returning a single tensor.
    ops.def(
        "s2_attention_fwd_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor");
    ops.impl(
        "s2_attention_fwd_cuda",
        torch::kCUDA,
        &s2_attention_fwd_cuda);
}

// NOTE(review): REGISTER_EXTENSION is not defined in this file — presumably it
// is provided by "registration.h" and emits the module entry point that lets
// the compiled library be loaded under TORCH_EXTENSION_NAME; confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)