medmekk HF Staff commited on
Commit
84ec9f0
·
verified ·
1 Parent(s): 57f64dd

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. README.md +3 -0
  3. build.toml +35 -0
  4. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  5. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  6. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  7. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  8. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  9. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  10. build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  11. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  12. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  13. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  14. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  15. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  16. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  17. build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  18. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  19. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  20. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  21. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  22. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  23. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  24. build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  25. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  26. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  27. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  28. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  29. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  30. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  31. build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  32. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  33. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  34. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  35. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  36. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  37. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  38. build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  39. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__init__.py +10 -0
  40. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc +0 -0
  41. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc +0 -0
  42. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc +0 -0
  43. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_attn_utils.py +637 -0
  44. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_ops.py +9 -0
  45. build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +3 -0
  46. flake.nix +13 -0
  47. nix-build.log +0 -0
  48. torch-ext/torch_binding.cpp +14 -0
  49. torch-ext/torch_binding.h +31 -0
  50. torch-ext/torch_harmonics_attn/__init__.py +10 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
37
+ build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
38
+ build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
39
+ build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
40
+ build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
41
+ build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## Torch Harmonics Attn
2
+
3
+ Attention mechanisms for the Spherical Harmonics basis using the torch-harmonics package : https://github.com/NVIDIA/torch-harmonics/tree/main/torch_harmonics/attention
build.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "torch_harmonics_attn"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h",
9
+ ]
10
+
11
+ [kernel.torch_harmonics_attn]
12
+ depends = ["torch"]
13
+ backend = "cuda"
14
+ cuda-capabilities = [
15
+ "7.5",
16
+ "8.0",
17
+ "8.9",
18
+ "9.0",
19
+ "10.0",
20
+ ]
21
+ src = [
22
+ "torch_harmonics_attn/attention_cpu_bwd.cpp",
23
+ "torch_harmonics_attn/attention_cpu_fwd.cpp",
24
+ "torch_harmonics_attn/attention_cpu.h",
25
+
26
+ "torch_harmonics_attn/attention_cuda_bwd.cu",
27
+ "torch_harmonics_attn/attention_cuda_fwd.cu",
28
+ "torch_harmonics_attn/attention_cuda_utils.cu",
29
+ "torch_harmonics_attn/attention_cuda_utils.cuh",
30
+ "torch_harmonics_attn/attention_cuda.cuh",
31
+
32
+ "torch_harmonics_attn/attention.h",
33
+ "torch_harmonics_attn/cudamacro.h"
34
+ ]
35
+
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4e9bb69e777ace94e18326ea2559292b3c0fbb11d68b185c1c4d700767ebf68
3
+ size 27631360
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a01d03d3f594f42388c5627a59cb8976d3e2fbb5f2adf76c4d5a5dc3f295d35a
3
+ size 27689536
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe35cb08c5705c56860da606c3b5480ef7880deaeb42eb0efcd4a37ef1bd70d6
3
+ size 35370448
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a1f5426e6d758a776dab4a8ccd4abecbf516f0c53d9884b44746cf5585898af
3
+ size 27627336
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3f834671fd44bea1d2e3cd23d4f99f5cb61ec7822b028830000b358f70797fe
3
+ size 35321056
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (436 Bytes). View file
 
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc ADDED
Binary file (27.2 kB). View file
 
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc ADDED
Binary file (570 Bytes). View file
 
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_attn_utils.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+
32
+ from typing import Union, Tuple
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from ._ops import ops
38
+
39
+ def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
40
+ return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
41
+
42
+ def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
43
+ return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)
44
+
45
+ def _setup_context_attention_backward(ctx, inputs, output):
46
+ k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
47
+ ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
48
+ ctx.nh = nh
49
+ ctx.max_psi_nnz = max_psi_nnz
50
+ ctx.nlon_in = nlon_in
51
+ ctx.nlat_out = nlat_out
52
+ ctx.nlon_out = nlon_out
53
+
54
+ def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
55
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
56
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
57
+ out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
58
+ return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)
59
+
60
+ def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
61
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
62
+ nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ dk = torch.empty_like(kw)
64
+ dv = torch.empty_like(vw)
65
+ dq = torch.empty_like(qw)
66
+ return dk, dv, dq
67
+
68
+ # forward
69
+ def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
70
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
71
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
72
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
73
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
74
+
75
+ kw = F.conv2d(k, weight=wk, bias=bk)
76
+ vw = F.conv2d(v, weight=wv, bias=bv)
77
+ qw = F.conv2d(q, weight=wq, bias=bq)
78
+
79
+ # reshape, folding num heads into batch dim
80
+ B, _, H, W = kw.shape
81
+ kw = kw.reshape(B*nh, -1, H, W)
82
+ B, _, H, W = vw.shape
83
+ vw = vw.reshape(B*nh, -1, H, W)
84
+ B, _, H, W = qw.shape
85
+ qw = qw.reshape(B*nh, -1, H, W)
86
+
87
+ # convert to float32
88
+ inp_dtype = kw.dtype
89
+ kw = kw.to(torch.float32).contiguous()
90
+ vw = vw.to(torch.float32).contiguous()
91
+ qw = qw.to(torch.float32).contiguous()
92
+
93
+ output = forward(kw, vw, qw, quad_weights,
94
+ col_idx, row_off,
95
+ nlon_in, nlat_out, nlon_out)
96
+
97
+ _, C, H, W = output.shape
98
+ output = output.reshape(B, -1, H, W)
99
+
100
+ # convert back precision
101
+ output = output.to(dtype=inp_dtype)
102
+
103
+ return output
104
+
105
+ def backward_optimized(ctx, grad_output):
106
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
107
+ nh = ctx.nh
108
+ max_psi_nnz = ctx.max_psi_nnz
109
+ nlon_in = ctx.nlon_in
110
+ nlat_out = ctx.nlat_out
111
+ nlon_out = ctx.nlon_out
112
+
113
+ # check if we need the grads at all
114
+ k_needs_grad = ctx.needs_input_grad[0]
115
+ v_needs_grad = ctx.needs_input_grad[1]
116
+ q_needs_grad = ctx.needs_input_grad[2]
117
+ wk_needs_grad = ctx.needs_input_grad[3]
118
+ wv_needs_grad = ctx.needs_input_grad[4]
119
+ wq_needs_grad = ctx.needs_input_grad[5]
120
+ bk_needs_grad = ctx.needs_input_grad[6]
121
+ bv_needs_grad = ctx.needs_input_grad[7]
122
+ bq_needs_grad = ctx.needs_input_grad[8]
123
+
124
+ kw = F.conv2d(k, weight=wk, bias=bk)
125
+ vw = F.conv2d(v, weight=wv, bias=bv)
126
+ qw = F.conv2d(q, weight=wq, bias=bq)
127
+
128
+ # reshape, folding num heads into batch dim
129
+ B, _, H, W = kw.shape
130
+ kw = kw.reshape(B*nh, -1, H, W)
131
+ B, _, H, W = vw.shape
132
+ vw = vw.reshape(B*nh, -1, H, W)
133
+ B, _, H, W = qw.shape
134
+ qw = qw.reshape(B*nh, -1, H, W)
135
+ B, _, H, W = grad_output.shape
136
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
137
+
138
+ # save type and convert to float32
139
+ kw_dtype = kw.dtype
140
+ vw_dtype = vw.dtype
141
+ qw_dtype = qw.dtype
142
+
143
+ kw = kw.to(torch.float32).contiguous()
144
+ vw = vw.to(torch.float32).contiguous()
145
+ qw = qw.to(torch.float32).contiguous()
146
+ grad_output = grad_output.to(torch.float32).contiguous()
147
+
148
+ dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
149
+ quad_weights,
150
+ col_idx, row_off,
151
+ nlon_in, nlat_out, nlon_out)
152
+
153
+ # weight grads
154
+ _, C, H, W = dkw.shape
155
+ dkw = dkw.reshape(B, -1, H, W)
156
+ dkw = dkw.to(dtype=kw_dtype)
157
+ if wk_needs_grad:
158
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
159
+ else:
160
+ dwk = None
161
+
162
+ _, C, H, W = dvw.shape
163
+ dvw = dvw.reshape(B, -1, H, W)
164
+ dvw = dvw.to(dtype=vw_dtype)
165
+ if wv_needs_grad:
166
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
167
+ else:
168
+ dwv = None
169
+
170
+ _, C, H, W = dqw.shape
171
+ dqw = dqw.reshape(B, -1, H, W)
172
+ dqw = dqw.to(dtype=qw_dtype)
173
+ if wq_needs_grad:
174
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
175
+ else:
176
+ dwq = None
177
+
178
+ # input grads
179
+ if v_needs_grad:
180
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
181
+ else:
182
+ dv = None
183
+
184
+ if k_needs_grad:
185
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
186
+ else:
187
+ dk = None
188
+
189
+ if q_needs_grad:
190
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
191
+ else:
192
+ dq = None
193
+
194
+ # bias grads:
195
+ if bv_needs_grad:
196
+ dbv = torch.sum(dvw, dim=(0,2,3))
197
+ else:
198
+ dbv = None
199
+
200
+ if bk_needs_grad:
201
+ dbk = torch.sum(dkw, dim=(0,2,3))
202
+ else:
203
+ dbk = None
204
+
205
+ if bq_needs_grad:
206
+ dbq = torch.sum(dqw, dim=(0,2,3))
207
+ else:
208
+ dbq = None
209
+
210
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
211
+ None, None, None, None, None, None, None, None
212
+
213
+ # torch kernels
214
+ # uses qdotk_max update trick to avoid two loops when computing the softmax
215
+ # see e.g., https://arxiv.org/abs/1805.02867
216
+ # and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
217
+ def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
218
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
219
+ nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
220
+
221
+
222
+ # prepare result tensor
223
+ out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
224
+ y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)
225
+
226
+ for ho in range(nlat_out):
227
+
228
+ # get number of nonzeros
229
+ zstart = row_off[ho]
230
+ zend = row_off[ho+1]
231
+
232
+ for wo in range(nlon_out):
233
+
234
+ alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
235
+ qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
236
+
237
+ for idz in range(zstart, zend):
238
+ nz_col_idx = col_idx[idz]
239
+
240
+ # compute input indices from psi datastructure
241
+ hi = nz_col_idx // nlon_in
242
+ # account for output shift and ensure positive index due to circular condition
243
+ wi = nz_col_idx % nlon_in
244
+ wip = (wi + wo) % nlon_in
245
+
246
+ # compute correlation & softmax numerator
247
+ q_ho_wo = qy[:, :, ho, wo]
248
+ k_hi_wip = kx[:, :, hi, wip]
249
+ qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
250
+
251
+ # tmp max
252
+ qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
253
+
254
+ # alpha sum update
255
+ alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
256
+ alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
257
+ # update output
258
+ y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]
259
+
260
+ # define new max
261
+ qdotk_max = qdotk_max_tmp
262
+
263
+ y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]
264
+
265
+ return y
266
+
267
+ # Explicit gradient w.r.t. vx: dM/dv
268
+ # provided as a reference for CUDA & other hand-written gradients
269
+ def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
270
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
271
+ nlon_in: int, nlat_out: int, nlon_out: int):
272
+
273
+ # shapes:
274
+ # input
275
+ # kx: B, C, Hi, Wi
276
+ # vx: B, Cout, Hi, Wi
277
+ # qy: B, Cout, Ho, Wo
278
+ # quad_weights: Hi
279
+ # output
280
+ # dvx: B, Cout, Hi, Wi
281
+
282
+ dvx = torch.zeros_like(vx)
283
+ batch_size = dy.shape[0]
284
+
285
+ for ho in range(nlat_out):
286
+
287
+ # get number of nonzeros
288
+ zstart = row_off[ho]
289
+ zend = row_off[ho+1]
290
+
291
+ for wo in range(nlon_out):
292
+
293
+ alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
294
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
295
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
296
+ for idz in range(zstart, zend):
297
+ nz_col_idx = col_idx[idz]
298
+
299
+ # compute input indices from psi datastructure
300
+ hi = nz_col_idx // nlon_in
301
+ # account for output shift and ensure positive index due to circular condition
302
+ wi = nz_col_idx % nlon_in
303
+ wip = (wi+wo) % nlon_in
304
+
305
+ # compute correlation & softmax numerator
306
+ q_ho_wo = qy[:, :, ho, wo]
307
+ k_hi_wi = kx[:, :, hi, wip]
308
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
309
+
310
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
311
+
312
+ for idz in range(zstart, zend):
313
+ nz_col_idx = col_idx[idz]
314
+
315
+ # compute input indices from psi datastructure
316
+ hi = nz_col_idx // nlon_in
317
+ # account for output shift and ensure positive index due to circular condition
318
+ wi = nz_col_idx % nlon_in
319
+ wip = (wi+wo) % nlon_in
320
+ alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
321
+ alpha_sum[:] += alpha_nz[:,idz-zstart]
322
+
323
+ for idz in range(zstart, zend):
324
+ nz_col_idx = col_idx[idz]
325
+
326
+ # compute input indices from psi datastructure
327
+ hi = nz_col_idx // nlon_in
328
+ # account for output shift and ensure positive index due to circular condition
329
+ wi = nz_col_idx % nlon_in
330
+ wip = (wi+wo) % nlon_in
331
+ dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]
332
+
333
+ return dvx
334
+
335
+
336
+ # Explicit gradient w.r.t. kx: dM/dk
337
+ # provided as a reference for CUDA & other hand-written gradients
338
+ def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
339
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
340
+ nlon_in: int, nlat_out: int, nlon_out: int):
341
+
342
+ # shapes:
343
+ # input
344
+ # kx: B, C, Hi, Wi
345
+ # vx: B, Cout, Hi, Wi
346
+ # qy: B, C, Ho, Wo
347
+ # quad_weights: Hi
348
+ # output
349
+ # dkx: B, C, Hi, Wi
350
+
351
+ dkx = torch.zeros_like(kx)
352
+ batch_size = dy.shape[0]
353
+
354
+ for ho in range(nlat_out):
355
+
356
+ # get number of nonzeros
357
+ zstart = row_off[ho]
358
+ zend = row_off[ho+1]
359
+
360
+ for wo in range(nlon_out):
361
+
362
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
363
+ integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
364
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
365
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
366
+ for idz in range(zstart, zend):
367
+ nz_col_idx = col_idx[idz]
368
+
369
+ # compute input indices from psi datastructure
370
+ hj = nz_col_idx // nlon_in
371
+ # account for output shift and ensure positive index due to circular condition
372
+ wj = nz_col_idx % nlon_in
373
+ wjp = (wj+wo) % nlon_in
374
+
375
+ # compute correlation & softmax numerator
376
+ q_ho_wo = qy[:, :, ho, wo]
377
+ k_hj_wjp = kx[:, :, hj, wjp]
378
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)
379
+
380
+ qdotk_max, _ = torch.max(qdotk_nz, dim=1)
381
+
382
+ for idz in range(zstart, zend):
383
+ nz_col_idx = col_idx[idz]
384
+
385
+ # compute input indices from psi datastructure
386
+ hj = nz_col_idx // nlon_in
387
+ # account for output shift and ensure positive index due to circular condition
388
+ wj = nz_col_idx % nlon_in
389
+ wjp = (wj+wo) % nlon_in
390
+
391
+ alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
392
+ alpha_sum[:] += alpha[:, idz-zstart]
393
+
394
+ # input dot
395
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)
396
+
397
+ # integral term
398
+ integral[:] += alpha[:, idz-zstart] * gdotv[:]
399
+
400
+ integral[:] = integral[:] / alpha_sum[:]
401
+
402
+ for idz in range(zstart, zend):
403
+ nz_col_idx = col_idx[idz]
404
+
405
+ # compute input indices from psi datastructure
406
+ hi = nz_col_idx // nlon_in
407
+ # account for output shift and ensure positive index due to circular condition
408
+ wi = nz_col_idx % nlon_in
409
+ wip = (wi+wo) % nlon_in
410
+
411
+ # compute correlation & softmax numerator
412
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
413
+
414
+ dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])
415
+
416
+ return dkx
417
+
418
+ # Explicit gradient w.r.t. qy: dM/dq
419
+ # provided as a reference for CUDA & other hand-written gradients
420
+ def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
421
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
422
+ nlon_in: int, nlat_out: int, nlon_out: int):
423
+
424
+ # shapes:
425
+ # input
426
+ # kx: B, C, Hi, Wi
427
+ # vx: B, Cout, Hi, Wi
428
+ # qy: B, C, Ho, Wo
429
+ # quad_weights: Hi
430
+ # output
431
+ # dq: B, C, Ho, Wo
432
+
433
+ batch_size = dy.shape[0]
434
+ channels_in = kx.shape[1]
435
+ channels_out = vx.shape[1]
436
+
437
+ dqy = torch.zeros_like(qy)
438
+
439
+ for ho in range(nlat_out):
440
+
441
+ # get number of nonzeros
442
+ zstart = row_off[ho]
443
+ zend = row_off[ho+1]
444
+
445
+ for wo in range(nlon_out):
446
+
447
+ alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
448
+ qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
449
+ alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
450
+ alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
451
+ alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
452
+ alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
453
+ alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
454
+ for idz in range(zstart, zend):
455
+ nz_col_idx = col_idx[idz]
456
+
457
+ # compute input indices from psi datastructure
458
+ hi = nz_col_idx // nlon_in
459
+ # account for output shift and ensure positive index due to circular condition
460
+ wi = nz_col_idx % nlon_in
461
+ wip = (wi+wo) % nlon_in
462
+
463
+ idz_i = idz-zstart
464
+
465
+ # compute correlation & softmax numerator
466
+ q_ho_wo = qy[:, :, ho, wo]
467
+ k_hi_wi = kx[:, :, hi, wip]
468
+ qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)
469
+
470
+ qdotk_max,_ = qdotk_nz.max(dim=1)
471
+
472
+ for idz in range(zstart, zend):
473
+ nz_col_idx = col_idx[idz]
474
+
475
+ # compute input indices from psi datastructure
476
+ hi = nz_col_idx // nlon_in
477
+ # account for output shift and ensure positive index due to circular condition
478
+ wi = nz_col_idx % nlon_in
479
+ wip = (wi+wo) % nlon_in
480
+
481
+ q_ho_wo = qy[:, :, ho, wo]
482
+ k_hi_wi = kx[:, :, hi, wip]
483
+ idz_i = idz-zstart
484
+ alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
485
+ alpha_sum[:] += alpha[:, idz_i]
486
+
487
+ gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
488
+ alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
489
+ alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
490
+ alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]
491
+
492
+ dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])
493
+
494
+ return dqy
495
+
496
+ def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
497
+ wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
498
+ bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
499
+ quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
500
+ max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
501
+ kw = F.conv2d(k, weight=wk, bias=bk)
502
+ vw = F.conv2d(v, weight=wv, bias=bv)
503
+ qw = F.conv2d(q, weight=wq, bias=bq)
504
+
505
+ # reshape, folding num heads into batch dim
506
+ B, _, H, W = kw.shape
507
+ kw = kw.reshape(B*nh, -1, H, W)
508
+ B, _, H, W = vw.shape
509
+ vw = vw.reshape(B*nh, -1, H, W)
510
+ B, _, H, W = qw.shape
511
+ qw = qw.reshape(B*nh, -1, H, W)
512
+
513
+ kw = kw.to(torch.float32)
514
+ vw = vw.to(torch.float32)
515
+ qw = qw.to(torch.float32)
516
+
517
+ output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
518
+ col_idx, row_off,
519
+ nlon_in, nlat_out, nlon_out)
520
+
521
+ _, C, H, W = output.shape
522
+ output = output.reshape(B, -1, H, W)
523
+
524
+ return output
525
+
526
+ def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
527
+ col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
528
+ nh = ctx.nh
529
+ nlon_in = ctx.nlon_in
530
+ nlat_out = ctx.nlat_out
531
+ nlon_out = ctx.nlon_out
532
+
533
+ # check if we need the grads at all
534
+ k_needs_grad = ctx.needs_input_grad[0]
535
+ v_needs_grad = ctx.needs_input_grad[1]
536
+ q_needs_grad = ctx.needs_input_grad[2]
537
+ wk_needs_grad = ctx.needs_input_grad[3]
538
+ wv_needs_grad = ctx.needs_input_grad[4]
539
+ wq_needs_grad = ctx.needs_input_grad[5]
540
+ bk_needs_grad = ctx.needs_input_grad[6]
541
+ bv_needs_grad = ctx.needs_input_grad[7]
542
+ bq_needs_grad = ctx.needs_input_grad[8]
543
+
544
+ kw = F.conv2d(k, weight=wk, bias=bk)
545
+ vw = F.conv2d(v, weight=wv, bias=bv)
546
+ qw = F.conv2d(q, weight=wq, bias=bq)
547
+
548
+ # reshape, folding num heads into batch dim
549
+ B, _, H, W = kw.shape
550
+ kw = kw.reshape(B*nh, -1, H, W)
551
+ B, _, H, W = vw.shape
552
+ vw = vw.reshape(B*nh, -1, H, W)
553
+ B, _, H, W = qw.shape
554
+ qw = qw.reshape(B*nh, -1, H, W)
555
+ B, _, H, W = grad_output.shape
556
+ grad_output = grad_output.reshape(B*nh, -1, H, W)
557
+
558
+ if v_needs_grad or wv_needs_grad or bv_needs_grad:
559
+ dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
560
+ quad_weights,
561
+ col_idx, row_off,
562
+ nlon_in, nlat_out, nlon_out)
563
+ _, C, H, W = dvw.shape
564
+ dvw = dvw.reshape(B, -1, H, W)
565
+ else:
566
+ dvw = None
567
+
568
+ if k_needs_grad or wk_needs_grad or bk_needs_grad:
569
+ dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
570
+ quad_weights,
571
+ col_idx, row_off,
572
+ nlon_in, nlat_out, nlon_out)
573
+ _, C, H, W = dkw.shape
574
+ dkw = dkw.reshape(B, -1, H, W)
575
+ else:
576
+ dkw = None
577
+
578
+ if q_needs_grad or wq_needs_grad or bq_needs_grad:
579
+ dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
580
+ quad_weights,
581
+ col_idx, row_off,
582
+ nlon_in, nlat_out, nlon_out)
583
+ _, C, H, W = dqw.shape
584
+ dqw = dqw.reshape(B, -1, H, W)
585
+ else:
586
+ dqw = None
587
+
588
+ # input grads
589
+ if v_needs_grad:
590
+ dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
591
+ else:
592
+ dv = None
593
+
594
+ if k_needs_grad:
595
+ dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
596
+ else:
597
+ dk = None
598
+
599
+ if q_needs_grad:
600
+ dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
601
+ else:
602
+ dq = None
603
+
604
+ # weight grads
605
+ if wv_needs_grad:
606
+ dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
607
+ else:
608
+ dwv = None
609
+
610
+ if wk_needs_grad:
611
+ dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
612
+ else:
613
+ dwk = None
614
+
615
+ if wq_needs_grad:
616
+ dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
617
+ else:
618
+ dwq = None
619
+
620
+ # bias grads:
621
+ if bv_needs_grad:
622
+ dbv = torch.sum(dvw, dim=(0,2,3))
623
+ else:
624
+ dbv = None
625
+
626
+ if bk_needs_grad:
627
+ dbk = torch.sum(dkw, dim=(0,2,3))
628
+ else:
629
+ dbk = None
630
+
631
+ if bq_needs_grad:
632
+ dbq = torch.sum(dqw, dim=(0,2,3))
633
+ else:
634
+ dbq = None
635
+
636
+ return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
637
+ None, None, None, None, None, None, None, None
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from . import _torch_harmonics_attn_20251001150033
3
+ ops = torch.ops._torch_harmonics_attn_20251001150033
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_torch_harmonics_attn_20251001150033::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1e4408020fb8b28578efcad9e4f0358b96e643c9e9c18bd5d4e589112d94d84
3
+ size 34089304
flake.nix ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for Torch kernel extension";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs = { self, kernel-builder, }:
9
+ kernel-builder.lib.genFlakeOutputs {
10
+ path = ./.;
11
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
12
+ };
13
+ }
nix-build.log ADDED
The diff for this file is too large to render. See raw diff
 
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+ #include "torch_binding.h"
5
+
6
+
7
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
8
+ ops.def("s2_attention_bwd_dkvq_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)");
9
+ ops.impl("s2_attention_bwd_dkvq_cuda", torch::kCUDA, &s2_attention_bwd_dkvq_cuda);
10
+ ops.def("s2_attention_fwd_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor");
11
+ ops.impl("s2_attention_fwd_cuda", torch::kCUDA, &s2_attention_fwd_cuda);
12
+ }
13
+
14
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+ #include <cstdint>
5
+ #include <tuple>
6
+
7
+
8
+ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(
9
+ at::Tensor kx,
10
+ at::Tensor vx,
11
+ at::Tensor qy,
12
+ at::Tensor dy,
13
+ at::Tensor quad_weights,
14
+ at::Tensor psi_col_idx,
15
+ at::Tensor psi_row_off,
16
+ int64_t nlon_in,
17
+ int64_t nlat_out,
18
+ int64_t nlon_out
19
+ );
20
+
21
+ torch::Tensor s2_attention_fwd_cuda(
22
+ at::Tensor kx,
23
+ at::Tensor vx,
24
+ at::Tensor qy,
25
+ at::Tensor quad_weights,
26
+ at::Tensor psi_col_idx,
27
+ at::Tensor psi_row_off,
28
+ int64_t nlon_in,
29
+ int64_t nlat_out,
30
+ int64_t nlon_out
31
+ );
torch-ext/torch_harmonics_attn/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch
2
+
3
+ __all__ = [
4
+ "backward",
5
+ "forward",
6
+ "forward_optimized",
7
+ "backward_optimized",
8
+ "_neighborhood_s2_attention_fwd_torch",
9
+ "_neighborhood_s2_attention_bwd_torch",
10
+ ]