| std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda( | |
| at::Tensor kx, | |
| at::Tensor vx, | |
| at::Tensor qy, | |
| at::Tensor dy, | |
| at::Tensor quad_weights, | |
| at::Tensor psi_col_idx, | |
| at::Tensor psi_row_off, | |
| int64_t nlon_in, | |
| int64_t nlat_out, | |
| int64_t nlon_out | |
| ); | |
| torch::Tensor s2_attention_fwd_cuda( | |
| at::Tensor kx, | |
| at::Tensor vx, | |
| at::Tensor qy, | |
| at::Tensor quad_weights, | |
| at::Tensor psi_col_idx, | |
| at::Tensor psi_row_off, | |
| int64_t nlon_in, | |
| int64_t nlat_out, | |
| int64_t nlon_out | |
| ); |