#include #include "registration.h" #include "torch_binding.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("s2_attention_bwd_dkvq_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)"); ops.impl("s2_attention_bwd_dkvq_cuda", torch::kCUDA, &s2_attention_bwd_dkvq_cuda); ops.def("s2_attention_fwd_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor"); ops.impl("s2_attention_fwd_cuda", torch::kCUDA, &s2_attention_fwd_cuda); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)