torch_harmonics_attn / torch-ext /torch_binding.cpp
medmekk's picture
medmekk HF Staff
Upload folder using huggingface_hub
84ec9f0 verified
raw
history blame contribute delete
740 Bytes
#include <torch/library.h>
#include "registration.h"
#include "torch_binding.h"
// Registers the S2 (spherical) attention custom operators under the
// TORCH_EXTENSION_NAME namespace. Each operator gets a schema via `def`,
// followed by a CUDA-only kernel via `impl`; no CPU kernel is registered
// here.
//
// TORCH_LIBRARY_EXPAND is presumably a registration.h wrapper that expands
// TORCH_EXTENSION_NAME before forwarding to TORCH_LIBRARY — confirm there.
// The kernel functions are presumably declared in torch_binding.h.
//
// NOTE: in the schema language, `int` is a 64-bit integer (C++ int64_t),
// so the kernel declarations must take int64_t for these to compile.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Forward pass: single tensor output.
  // psi_col_idx / psi_row_off look like a CSR-style sparsity pattern —
  // NOTE(review): confirm against the kernel implementation.
  ops.def(
      "s2_attention_fwd_cuda("
      "Tensor kx, Tensor vx, Tensor qy, "
      "Tensor quad_weights, "
      "Tensor psi_col_idx, Tensor psi_row_off, "
      "int nlon_in, int nlat_out, int nlon_out"
      ") -> Tensor");
  ops.impl("s2_attention_fwd_cuda", torch::kCUDA, &s2_attention_fwd_cuda);

  // Backward pass: takes the upstream gradient `dy` in addition to the
  // forward inputs and returns three tensors. The op name suggests these are
  // the gradients w.r.t. k, v and q — TODO confirm the output order.
  ops.def(
      "s2_attention_bwd_dkvq_cuda("
      "Tensor kx, Tensor vx, Tensor qy, Tensor dy, "
      "Tensor quad_weights, "
      "Tensor psi_col_idx, Tensor psi_row_off, "
      "int nlon_in, int nlat_out, int nlon_out"
      ") -> (Tensor, Tensor, Tensor)");
  ops.impl("s2_attention_bwd_dkvq_cuda", torch::kCUDA, &s2_attention_bwd_dkvq_cuda);
}
// Presumably defined in registration.h: emits the Python module entry point
// for TORCH_EXTENSION_NAME so the compiled library can be imported, which in
// turn runs the operator registrations above.
// NOTE(review): confirm the exact expansion in registration.h.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)