| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |
| ops.def("s2_attention_bwd_dkvq_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)"); | |
| ops.impl("s2_attention_bwd_dkvq_cuda", torch::kCUDA, &s2_attention_bwd_dkvq_cuda); | |
| ops.def("s2_attention_fwd_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor"); | |
| ops.impl("s2_attention_fwd_cuda", torch::kCUDA, &s2_attention_fwd_cuda); | |
| } | |
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |