#pragma once #include #include #include std::tuple s2_attention_bwd_dkvq_cuda( at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor dy, at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out ); torch::Tensor s2_attention_fwd_cuda( at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, at::Tensor psi_col_idx, at::Tensor psi_row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out );