diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..fc6a6c836d04da79b0393d1d0f7f2a73ad77df36 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4c476139f94bc06a61b1be679f2b1a39b9fafc54 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +## Torch Harmonics Attn + +Attention mechanisms for the Spherical Harmonics basis using the torch-harmonics package : https://github.com/NVIDIA/torch-harmonics/tree/main/torch_harmonics/attention \ No newline at end of file diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..c5f1e5086b9571f519876972ed827624629af15b --- /dev/null +++ b/build.toml @@ -0,0 +1,35 @@ +[general] +name = "torch_harmonics_attn" +universal = false + +[torch] +src = [ + "torch-ext/torch_binding.cpp", + "torch-ext/torch_binding.h", +] + 
+[kernel.torch_harmonics_attn] +depends = ["torch"] +backend = "cuda" +cuda-capabilities = [ + "7.5", + "8.0", + "8.9", + "9.0", + "10.0", +] +src = [ + "torch_harmonics_attn/attention_cpu_bwd.cpp", + "torch_harmonics_attn/attention_cpu_fwd.cpp", + "torch_harmonics_attn/attention_cpu.h", + + "torch_harmonics_attn/attention_cuda_bwd.cu", + "torch_harmonics_attn/attention_cuda_fwd.cu", + "torch_harmonics_attn/attention_cuda_utils.cu", + "torch_harmonics_attn/attention_cuda_utils.cuh", + "torch_harmonics_attn/attention_cuda.cuh", + + "torch_harmonics_attn/attention.h", + "torch_harmonics_attn/cudamacro.h" +] + diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af167ad29fc7219e3c72faff333c379c15c7853 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc 
b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9d531db2b181b47a318c930ddbcad9639346406 Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..958819f274362997a8e5555c962f88a8febf780a Binary files /dev/null and b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: 
torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = 
ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # 
input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, 
ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz 
in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _torch_harmonics_attn_20251001150033 +ops = torch.ops._torch_harmonics_attn_20251001150033 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..fd477fa93d379d26b48c5738daafbe757c5ca034 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4e9bb69e777ace94e18326ea2559292b3c0fbb11d68b185c1c4d700767ebf68 +size 27631360 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8ce5d6838a58748be9473ce0d623ece1e5438be Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04a4cdaa52c9fb0118a415f1c3623562328b691c Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..221929d898f58587a95ca80305d7d17ed3b908e1 Binary files /dev/null and b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: 
torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = 
ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # 
input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, 
ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz 
in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _torch_harmonics_attn_20251001150033 +ops = torch.ops._torch_harmonics_attn_20251001150033 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c24d0b903c19e6a621e89442438a678ea20d131e --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a01d03d3f594f42388c5627a59cb8976d3e2fbb5f2adf76c4d5a5dc3f295d35a +size 27689536 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdfdce9a6dbcffb47d9aee778f606dd4fdf00e05 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..733f35f6e9c4f20df3801a1c64d59a954f412757 Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eccbce74c4a01e7a8330a6b9bccdb2f517b218e Binary files /dev/null and b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: 
torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = 
ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # 
input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, 
ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz 
in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _torch_harmonics_attn_20251001150033 +ops = torch.ops._torch_harmonics_attn_20251001150033 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..705bfefedd65567de910ffdc3f45646f57c51ffd --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe35cb08c5705c56860da606c3b5480ef7880deaeb42eb0efcd4a37ef1bd70d6 +size 35370448 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfbc3c4ccd249f7f70921dab089d18d9f32ec5fd Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83cacb6d21a4b929e89141f8f8ad8cb9a332265b Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ba0d9a596fc93eef0f1432ddbdfc59a20507799 Binary files /dev/null and b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: 
torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = 
ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # 
input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, 
ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz 
in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _torch_harmonics_attn_20251001150033 +ops = torch.ops._torch_harmonics_attn_20251001150033 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..28b6c92444825ee9a7e1af216e671e9e1244897d --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1a1f5426e6d758a776dab4a8ccd4abecbf516f0c53d9884b44746cf5585898af +size 27627336 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5a0bf692413dac0614428ab1dbed13e7b2de6a6 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c5ff3459c91cf736728d0ad64c5d83ce9a0370a Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beb670b1a757f06a98862d0d7d68d07258da1ad3 Binary files /dev/null and b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: 
torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = 
ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # 
input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, 
ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz 
in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _torch_harmonics_attn_20251001150033 +ops = torch.ops._torch_harmonics_attn_20251001150033 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..18486bb8f626913497ab99b7b325313f675cc925 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e3f834671fd44bea1d2e3cd23d4f99f5cb61ec7822b028830000b358f70797fe +size 35321056 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__init__.py b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..231c1d286ba81b03892934b554205849c8f527b5 Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75e1b3a7e19e7f0f83546f1232885447ab6cd488 Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_attn_utils.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..674b478087b30b29750f013b378d155e980e1a3b Binary files /dev/null and b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_attn_utils.py b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
# Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

from typing import Union, Tuple

import torch
import torch.nn.functional as F

from ._ops import ops


def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
    # Thin wrapper around the compiled CUDA backward op.  Returns the (dk, dv, dq)
    # gradients; the kernel itself lives in the compiled extension and is not
    # visible here.
    return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)


def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out):
    # Thin wrapper around the compiled CUDA forward op for neighborhood S2 attention.
    return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out)


def _setup_context_attention_backward(ctx, inputs, output):
    # autograd.Function setup_context hook: stashes tensors and integer
    # hyperparameters needed by backward_optimized / _neighborhood_s2_attention_bwd_torch.
    # The save order (col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
    # must match the unpack order in both backward implementations below.
    k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs
    ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq)
    ctx.nh = nh
    ctx.max_psi_nnz = max_psi_nnz
    ctx.nlon_in = nlon_in
    ctx.nlat_out = nlat_out
    ctx.nlon_out = nlon_out


def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor,
                    quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                    nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    # "Meta"/fake implementation: only allocates an output of the correct shape
    # (B, C_v, nlat_out, nlon_out) without computing anything.  Contents are
    # uninitialized (torch.empty).
    out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out)
    return torch.empty(out_shape, dtype=kw.dtype, device=kw.device)


def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor,
                     quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                     nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # "Meta"/fake backward: allocates gradient tensors shaped like the inputs
    # without computing anything.  Contents are uninitialized (torch.empty_like).
    dk = torch.empty_like(kw)
    dv = torch.empty_like(vw)
    dq = torch.empty_like(qw)
    return dk, dv, dq


# forward
def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                      wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                      bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                      quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                      max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Fused multi-head neighborhood S2 attention forward (CUDA path).

    Projects k/v/q with 1x1 convolutions (wk/wv/wq, optional biases), folds the
    ``nh`` heads into the batch dimension, runs the CUDA attention kernel in
    float32, then unfolds the heads and casts the output back to the input dtype.
    ``max_psi_nnz`` is accepted for signature parity but unused here.
    """
    kw = F.conv2d(k, weight=wk, bias=bk)
    vw = F.conv2d(v, weight=wv, bias=bv)
    qw = F.conv2d(q, weight=wq, bias=bq)

    # reshape, folding num heads into batch dim
    # (B captured *before* each reshape, so after the loop B is the original batch size)
    B, _, H, W = kw.shape
    kw = kw.reshape(B*nh, -1, H, W)
    B, _, H, W = vw.shape
    vw = vw.reshape(B*nh, -1, H, W)
    B, _, H, W = qw.shape
    qw = qw.reshape(B*nh, -1, H, W)

    # convert to float32 (the CUDA kernel operates in fp32; .contiguous() because
    # the kernel indexes raw memory)
    inp_dtype = kw.dtype
    kw = kw.to(torch.float32).contiguous()
    vw = vw.to(torch.float32).contiguous()
    qw = qw.to(torch.float32).contiguous()

    output = forward(kw, vw, qw, quad_weights,
                     col_idx, row_off,
                     nlon_in, nlat_out, nlon_out)

    # unfold heads back out of the batch dimension
    _, C, H, W = output.shape
    output = output.reshape(B, -1, H, W)

    # convert back precision
    output = output.to(dtype=inp_dtype)

    return output


def backward_optimized(ctx, grad_output):
    """Backward for forward_optimized (CUDA path).

    Recomputes the k/v/q projections (activation recomputation instead of
    saving them), folds heads into batch, runs the CUDA backward kernel in
    float32, then assembles input, weight (via einsum over the 1x1 conv), and
    bias gradients.  Returns one gradient slot per forward input; the trailing
    Nones cover the non-differentiable args (quad_weights, col_idx, row_off,
    max_psi_nnz, nh, nlon_in, nlat_out, nlon_out).
    """
    col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
    nh = ctx.nh
    max_psi_nnz = ctx.max_psi_nnz
    nlon_in = ctx.nlon_in
    nlat_out = ctx.nlat_out
    nlon_out = ctx.nlon_out

    # check if we need the grads at all
    k_needs_grad = ctx.needs_input_grad[0]
    v_needs_grad = ctx.needs_input_grad[1]
    q_needs_grad = ctx.needs_input_grad[2]
    wk_needs_grad = ctx.needs_input_grad[3]
    wv_needs_grad = ctx.needs_input_grad[4]
    wq_needs_grad = ctx.needs_input_grad[5]
    bk_needs_grad = ctx.needs_input_grad[6]
    bv_needs_grad = ctx.needs_input_grad[7]
    bq_needs_grad = ctx.needs_input_grad[8]

    # recompute projections (cheaper than saving the projected activations)
    kw = F.conv2d(k, weight=wk, bias=bk)
    vw = F.conv2d(v, weight=wv, bias=bv)
    qw = F.conv2d(q, weight=wq, bias=bq)

    # reshape, folding num heads into batch dim
    B, _, H, W = kw.shape
    kw = kw.reshape(B*nh, -1, H, W)
    B, _, H, W = vw.shape
    vw = vw.reshape(B*nh, -1, H, W)
    B, _, H, W = qw.shape
    qw = qw.reshape(B*nh, -1, H, W)
    B, _, H, W = grad_output.shape
    grad_output = grad_output.reshape(B*nh, -1, H, W)

    # save type and convert to float32
    kw_dtype = kw.dtype
    vw_dtype = vw.dtype
    qw_dtype = qw.dtype

    kw = kw.to(torch.float32).contiguous()
    vw = vw.to(torch.float32).contiguous()
    qw = qw.to(torch.float32).contiguous()
    grad_output = grad_output.to(torch.float32).contiguous()

    dkw, dvw, dqw = backward(kw, vw, qw, grad_output,
                             quad_weights,
                             col_idx, row_off,
                             nlon_in, nlat_out, nlon_out)

    # weight grads
    # 1x1 conv weight grad via einsum: contract batch and spatial dims of the
    # projected-activation grad against the unprojected input.
    _, C, H, W = dkw.shape
    dkw = dkw.reshape(B, -1, H, W)
    dkw = dkw.to(dtype=kw_dtype)
    if wk_needs_grad:
        dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
    else:
        dwk = None

    _, C, H, W = dvw.shape
    dvw = dvw.reshape(B, -1, H, W)
    dvw = dvw.to(dtype=vw_dtype)
    if wv_needs_grad:
        dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
    else:
        dwv = None

    _, C, H, W = dqw.shape
    dqw = dqw.reshape(B, -1, H, W)
    dqw = dqw.to(dtype=qw_dtype)
    if wq_needs_grad:
        dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
    else:
        dwq = None

    # input grads
    # grad of a 1x1 conv w.r.t. its input = conv with the transposed weight
    if v_needs_grad:
        dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
    else:
        dv = None

    if k_needs_grad:
        dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
    else:
        dk = None

    if q_needs_grad:
        dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
    else:
        dq = None

    # bias grads: sum the activation grad over batch and spatial dims
    if bv_needs_grad:
        dbv = torch.sum(dvw, dim=(0,2,3))
    else:
        dbv = None

    if bk_needs_grad:
        dbk = torch.sum(dkw, dim=(0,2,3))
    else:
        dbk = None

    if bq_needs_grad:
        dbq = torch.sum(dqw, dim=(0,2,3))
    else:
        dbq = None

    return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
        None, None, None, None, None, None, None, None


# torch kernels
# uses qdotk_max update trick to avoid two loops when computing the softmax
# see e.g., https://arxiv.org/abs/1805.02867
# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                         quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                         nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Pure-PyTorch reference forward for neighborhood S2 attention.

    The sparse neighborhood is given in CSR form: row_off[ho]..row_off[ho+1]
    indexes col_idx, whose flat entries encode (hi, wi) as hi*nlon_in + wi.
    Softmax weights are scaled by the quadrature weight of the input latitude.
    Uses a single-pass streaming softmax with a running max and rescaled
    partial sums.
    """

    # prepare result tensor
    out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out)
    y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device)

    for ho in range(nlat_out):

        # get number of nonzeros
        zstart = row_off[ho]
        zend = row_off[ho+1]

        for wo in range(nlon_out):

            alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
            # NOTE(review): running max starts at 0 rather than -inf; softmax
            # ratios are shift-invariant so the result is unchanged, but -inf
            # would give more numerical headroom when all logits are negative —
            # confirm intent before changing.
            qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)

            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi + wo) % nlon_in

                # compute correlation & softmax numerator
                q_ho_wo = qy[:, :, ho, wo]
                k_hi_wip = kx[:, :, hi, wip]
                qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)

                # tmp max
                qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)

                # alpha sum update (rescale previous partial sum by the max shift)
                alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
                alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
                # update output
                y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:,:,hi,wip]

                # define new max
                qdotk_max = qdotk_max_tmp

            # final softmax normalization
            y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None]

    return y


# Explicit gradient w.r.t. vx: dM/dv
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """Reference gradient of the attention output w.r.t. vx.

    Three passes per output point: (1) logits, (2) softmax weights and their
    sum (two-pass max for stability), (3) scatter dy weighted by alpha/alpha_sum
    into dvx.
    """

    # shapes:
    # input
    #  kx: B, C, Hi, Wi
    #  vx: B, Cout, Hi, Wi
    #  qy: B, Cout, Ho, Wo
    #  quad_weights: Hi
    # output
    #  dvx: B, Cout, Hi, Wi

    dvx = torch.zeros_like(vx)
    batch_size = dy.shape[0]

    for ho in range(nlat_out):

        # get number of nonzeros
        zstart = row_off[ho]
        zend = row_off[ho+1]

        for wo in range(nlon_out):

            alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
            # pass 1: logits for every nonzero of this row
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in

                # compute correlation & softmax numerator
                q_ho_wo = qy[:, :, ho, wo]
                k_hi_wi = kx[:, :, hi, wip]
                qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)

            qdotk_max, _ = torch.max(qdotk_nz, dim=1)

            # pass 2: quadrature-weighted softmax numerators and their sum
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in
                alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
                alpha_sum[:] += alpha_nz[:,idz-zstart]

            # pass 3: accumulate softmax-weighted output grad into dvx
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in
                dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo]

    return dvx


# Explicit gradient w.r.t. kx: dM/dk
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """Reference gradient of the attention output w.r.t. kx.

    Uses the softmax-jacobian identity: dk contribution is
    q * alpha/alpha_sum * (g·v - E[g·v]), where the expectation ("integral")
    is the alpha-weighted mean of g·v over the neighborhood.
    """

    # shapes:
    # input
    #  kx: B, C, Hi, Wi
    #  vx: B, Cout, Hi, Wi
    #  qy: B, C, Ho, Wo
    #  quad_weights: Hi
    # output
    #  dkx: B, C, Hi, Wi

    dkx = torch.zeros_like(kx)
    batch_size = dy.shape[0]

    for ho in range(nlat_out):

        # get number of nonzeros
        zstart = row_off[ho]
        zend = row_off[ho+1]

        for wo in range(nlon_out):

            qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
            alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
            # pass 1: logits for every nonzero of this row
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hj = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wj = nz_col_idx % nlon_in
                wjp = (wj+wo) % nlon_in

                # compute correlation & softmax numerator
                q_ho_wo = qy[:, :, ho, wo]
                k_hj_wjp = kx[:, :, hj, wjp]
                qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1)

            qdotk_max, _ = torch.max(qdotk_nz, dim=1)

            # pass 2: softmax weights, their sum, and the alpha-weighted g·v integral
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hj = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wj = nz_col_idx % nlon_in
                wjp = (wj+wo) % nlon_in

                alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj]
                alpha_sum[:] += alpha[:, idz-zstart]

                # input dot
                gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1)

                # integral term
                integral[:] += alpha[:, idz-zstart] * gdotv[:]

            integral[:] = integral[:] / alpha_sum[:]

            # pass 3: scatter the centered, softmax-weighted contribution into dkx
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in

                # compute correlation & softmax numerator
                gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)

                dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None])

    return dkx


# Explicit gradient w.r.t. qy: dM/dq
# provided as a reference for CUDA & other hand-written gradients
def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor,
                                            quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                            nlon_in: int, nlat_out: int, nlon_out: int):
    """Reference gradient of the attention output w.r.t. qy.

    Accumulates the quotient-rule terms of the softmax derivative:
    dq = (sum(alpha*k*g·v) * sum(alpha) - sum(alpha*g·v) * sum(alpha*k)) / sum(alpha)^2.
    """

    # shapes:
    # input
    #  kx: B, C, Hi, Wi
    #  vx: B, Cout, Hi, Wi
    #  qy: B, C, Ho, Wo
    #  quad_weights: Hi
    # output
    #  dq: B, C, Ho, Wo

    batch_size = dy.shape[0]
    channels_in = kx.shape[1]
    channels_out = vx.shape[1]  # NOTE(review): unused

    dqy = torch.zeros_like(qy)

    for ho in range(nlat_out):

        # get number of nonzeros
        zstart = row_off[ho]
        zend = row_off[ho+1]

        for wo in range(nlon_out):

            alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device)
            alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
            alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
            alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device)
            alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)
            alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device)  # NOTE(review): unused
            # pass 1: logits for every nonzero of this row
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in

                idz_i = idz-zstart  # NOTE(review): unused in this first pass

                # compute correlation & softmax numerator
                q_ho_wo = qy[:, :, ho, wo]
                k_hi_wi = kx[:, :, hi, wip]
                qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1)

            qdotk_max,_ = qdotk_nz.max(dim=1)

            # pass 2: accumulate the four softmax-derivative sums
            for idz in range(zstart, zend):
                nz_col_idx = col_idx[idz]

                # compute input indices from psi datastructure
                hi = nz_col_idx // nlon_in
                # account for output shift and ensure positive index due to circular condition
                wi = nz_col_idx % nlon_in
                wip = (wi+wo) % nlon_in

                q_ho_wo = qy[:, :, ho, wo]
                k_hi_wi = kx[:, :, hi, wip]
                idz_i = idz-zstart
                alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi]
                alpha_sum[:] += alpha[:, idz_i]

                gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1)
                alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi
                alpha_vw[:] += alpha[:, idz_i] * gdotv[:]
                alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None]

            # quotient rule of the normalized softmax expectation
            dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None])

    return dqy


def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor,
                                     wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor,
                                     bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None],
                                     quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                     max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
    """Pure-PyTorch counterpart of forward_optimized (same projection + head
    folding + fp32 attention, but using the torch reference kernel).

    NOTE(review): unlike forward_optimized, the output is NOT cast back to the
    input dtype here (it stays float32) and no .contiguous() is applied —
    confirm whether that asymmetry is intentional.
    """
    kw = F.conv2d(k, weight=wk, bias=bk)
    vw = F.conv2d(v, weight=wv, bias=bv)
    qw = F.conv2d(q, weight=wq, bias=bq)

    # reshape, folding num heads into batch dim
    B, _, H, W = kw.shape
    kw = kw.reshape(B*nh, -1, H, W)
    B, _, H, W = vw.shape
    vw = vw.reshape(B*nh, -1, H, W)
    B, _, H, W = qw.shape
    qw = qw.reshape(B*nh, -1, H, W)

    kw = kw.to(torch.float32)
    vw = vw.to(torch.float32)
    qw = qw.to(torch.float32)

    output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights,
                                                  col_idx, row_off,
                                                  nlon_in, nlat_out, nlon_out)

    # unfold heads back out of the batch dimension
    _, C, H, W = output.shape
    output = output.reshape(B, -1, H, W)

    return output


def _neighborhood_s2_attention_bwd_torch(ctx, grad_output):
    """Pure-PyTorch backward matching backward_optimized, built from the three
    reference gradients (dv, dk, dq).

    Recomputes the projections, folds heads into batch, computes each
    projected-activation gradient only if some consumer of it needs grad, then
    assembles input, weight, and bias gradients exactly as backward_optimized
    does.  Returns one slot per forward input with trailing Nones for the
    non-differentiable args.
    """
    col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors
    nh = ctx.nh
    nlon_in = ctx.nlon_in
    nlat_out = ctx.nlat_out
    nlon_out = ctx.nlon_out

    # check if we need the grads at all
    k_needs_grad = ctx.needs_input_grad[0]
    v_needs_grad = ctx.needs_input_grad[1]
    q_needs_grad = ctx.needs_input_grad[2]
    wk_needs_grad = ctx.needs_input_grad[3]
    wv_needs_grad = ctx.needs_input_grad[4]
    wq_needs_grad = ctx.needs_input_grad[5]
    bk_needs_grad = ctx.needs_input_grad[6]
    bv_needs_grad = ctx.needs_input_grad[7]
    bq_needs_grad = ctx.needs_input_grad[8]

    # recompute projections
    kw = F.conv2d(k, weight=wk, bias=bk)
    vw = F.conv2d(v, weight=wv, bias=bv)
    qw = F.conv2d(q, weight=wq, bias=bq)

    # reshape, folding num heads into batch dim
    B, _, H, W = kw.shape
    kw = kw.reshape(B*nh, -1, H, W)
    B, _, H, W = vw.shape
    vw = vw.reshape(B*nh, -1, H, W)
    B, _, H, W = qw.shape
    qw = qw.reshape(B*nh, -1, H, W)
    B, _, H, W = grad_output.shape
    grad_output = grad_output.reshape(B*nh, -1, H, W)

    if v_needs_grad or wv_needs_grad or bv_needs_grad:
        dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)
        _, C, H, W = dvw.shape
        dvw = dvw.reshape(B, -1, H, W)
    else:
        dvw = None

    if k_needs_grad or wk_needs_grad or bk_needs_grad:
        dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)
        _, C, H, W = dkw.shape
        dkw = dkw.reshape(B, -1, H, W)
    else:
        dkw = None

    if q_needs_grad or wq_needs_grad or bq_needs_grad:
        dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output,
                                                      quad_weights,
                                                      col_idx, row_off,
                                                      nlon_in, nlat_out, nlon_out)
        _, C, H, W = dqw.shape
        dqw = dqw.reshape(B, -1, H, W)
    else:
        dqw = None

    # input grads (transposed 1x1 conv)
    if v_needs_grad:
        dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None)
    else:
        dv = None

    if k_needs_grad:
        dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None)
    else:
        dk = None

    if q_needs_grad:
        dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None)
    else:
        dq = None

    # weight grads
    if wv_needs_grad:
        dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous()
    else:
        dwv = None

    if wk_needs_grad:
        dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous()
    else:
        dwk = None

    if wq_needs_grad:
        dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous()
    else:
        dwq = None

    # bias grads:
    if bv_needs_grad:
        dbv = torch.sum(dvw, dim=(0,2,3))
    else:
        dbv = None

    if bk_needs_grad:
        dbk = torch.sum(dkw, dim=(0,2,3))
    else:
        dbk = None

    if bq_needs_grad:
        dbq = torch.sum(dqw, dim=(0,2,3))
    else:
        dbq = None

    return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \
        None, None, None, None, None, None, None, None
diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc6561d6ce1c84bdcc5d680e1090e2e5af4d955a
--- /dev/null
+++ b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_ops.py
@@ -0,0 +1,9 @@
+import torch
+from . import _torch_harmonics_attn_20251001150033
+ops = torch.ops._torch_harmonics_attn_20251001150033
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace. 
+ """ + return f"_torch_harmonics_attn_20251001150033::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..7a29374bd45b1f97bf6638e020a982f7a6b296d5 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1e4408020fb8b28578efcad9e4f0358b96e643c9e9c18bd5d4e589112d94d84 +size 34089304 diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..ee6d68ec2b4e0ae4c7a9139538f9706cb31b1d2e --- /dev/null +++ b/flake.nix @@ -0,0 +1,13 @@ +{ + description = "Flake for Torch kernel extension"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = { self, kernel-builder, }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/nix-build.log b/nix-build.log new file mode 100644 index 0000000000000000000000000000000000000000..5bfe676c6aa2e48c377280fb4aac0e679bca1ac5 --- /dev/null +++ b/nix-build.log @@ -0,0 +1,860 @@ +warning: not writing modified lock file of flake 'path:/home/ec2-user/dev/torch_harmonics_attn': +• Added input 'kernel-builder': + 'github:huggingface/kernel-builder/437d0f5c253a78d0be8b5998d9c1fcf32ac2360c?narHash=sha256-RzjCEn0zDfdwQp4WAb0BBuLlHxypr%2B4%2Ba4BMON23SNw%3D' (2025-10-01) +• Added input 'kernel-builder/flake-compat': + 'github:edolstra/flake-compat/9100a0f413b0c601e0533d1d94ffd501ce2e7885?narHash=sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX%2BfjA8Xf8PUmqCY%3D' (2025-05-12) +• Added input 'kernel-builder/flake-utils': + 
'github:numtide/flake-utils/11707dc2f618dd54ca8739b309ec4fc024de578b?narHash=sha256-l0KFg5HjrsfsO/JpG%2Br7fRrqm12kzFHyUHqHCVpMMbI%3D' (2024-11-13) +• Added input 'kernel-builder/flake-utils/systems': + 'github:nix-systems/default/da67096a3b9bf56a91d16901293e51ba5b49a27e?narHash=sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768%3D' (2023-04-09) +• Added input 'kernel-builder/hf-nix': + 'github:huggingface/hf-nix/faf3354403a7381958d08e826c15fe30f6986a4f?narHash=sha256-JQKZOI1ZYO4faJnanuoTXziSmqzXe5rEFSGliWDWqWw%3D' (2025-09-12) +• Added input 'kernel-builder/hf-nix/flake-compat': + 'github:edolstra/flake-compat/9100a0f413b0c601e0533d1d94ffd501ce2e7885?narHash=sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX%2BfjA8Xf8PUmqCY%3D' (2025-05-12) +• Added input 'kernel-builder/hf-nix/flake-utils': + 'github:numtide/flake-utils/11707dc2f618dd54ca8739b309ec4fc024de578b?narHash=sha256-l0KFg5HjrsfsO/JpG%2Br7fRrqm12kzFHyUHqHCVpMMbI%3D' (2024-11-13) +• Added input 'kernel-builder/hf-nix/flake-utils/systems': + 'github:nix-systems/default/da67096a3b9bf56a91d16901293e51ba5b49a27e?narHash=sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768%3D' (2023-04-09) +• Added input 'kernel-builder/hf-nix/nixpkgs': + 'github:nixos/nixpkgs/73e96df7cff5783f45e21342a75a1540c4eddce4?narHash=sha256-6yD0ww/S8n%2BU2uPYcJZ3DRURP8Kx036GRpR2uPNZroE%3D' (2025-08-23) +• Added input 'kernel-builder/nixpkgs': + follows 'kernel-builder/hf-nix/nixpkgs' +evaluation warning: `rev` argument of `genFlakeOutputs` is deprecated, pass `self` as follows: + + kernel-builder.lib.genFlakeOutputs { + inherit self; + path = ./.; + }; +these 8 derivations will be built: + /nix/store/10rmsaxfc9kz090dlcdbwmmg6lc20rn9-torch_harmonics_attn-torch-ext.drv + /nix/store/119vi4cvkkcm3kya4pfg9497qx496hgn-torch_harmonics_attn-torch-ext.drv + /nix/store/6hkvq3na02vlyc59xhrhiyzxvy7yyx81-torch_harmonics_attn-torch-ext.drv + /nix/store/jgbsvp30czyym5dc3184gn0fxkq80zqz-torch_harmonics_attn-torch-ext.drv + 
/nix/store/pf1r215zajp82v4izkyfbfsvcipkmwfh-torch_harmonics_attn-torch-ext.drv + /nix/store/x11gxyjbl77gmj3i02481pzrakqa7rc9-torch_harmonics_attn-torch-ext.drv + /nix/store/nlrms1s9z75f22da29ac2vgmngzfdpay-torch-ext-bundle.drv + /nix/store/pg5s7j1lx0yqdbsxzkwm8jrr4bsls395-build-and-copy.drv +building '/nix/store/10rmsaxfc9kz090dlcdbwmmg6lc20rn9-torch_harmonics_attn-torch-ext.drv'... +building '/nix/store/119vi4cvkkcm3kya4pfg9497qx496hgn-torch_harmonics_attn-torch-ext.drv'... +building '/nix/store/6hkvq3na02vlyc59xhrhiyzxvy7yyx81-torch_harmonics_attn-torch-ext.drv'... +building '/nix/store/jgbsvp30czyym5dc3184gn0fxkq80zqz-torch_harmonics_attn-torch-ext.drv'... +building '/nix/store/pf1r215zajp82v4izkyfbfsvcipkmwfh-torch_harmonics_attn-torch-ext.drv'... +building '/nix/store/x11gxyjbl77gmj3i02481pzrakqa7rc9-torch_harmonics_attn-torch-ext.drv'... +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Sourcing get-kernel-check-hook.sh +torch_harmonics_attn-torch-ext> Sourcing setup-cuda-hook +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source 
+torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> fixing cmake files... +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> fixing cmake files... 
+torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> fixing cmake files... +torch_harmonics_attn-torch-ext> fixing cmake files... +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/libexec -DCMAKE_INSTALL_LIBDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW 
-DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/strip -DCMAKE_RANLIB=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/ranlib -DCMAKE_AR=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/ar -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext -DPython_EXECUTABLE:STRING=/nix/store/j6r6hpjs8p5m4s3i8cqqavg62fd5z48g-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/g++ -DNVCC_THREADS=5 -DCUDAToolkit_INCLUDE_DIR=/nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/include\;/nix/store/5f6h6xs5c74iqcjda3y73i290mfwfs9x-cuda_nvml_dev-12.6.77-dev/include\;/nix/store/r26q9f2lhsvimxha44g1xcck3adrdqwg-cuda_nvrtc-12.6.85-dev/include\;/nix/store/nj1a061pvzpq9dr65yj3jpjqcx6pr4fq-cuda_nvtx-12.6.77-dev/include\;/nix/store/9ik1skjb698l6vkx4m4wvx2nrr4sx0na-libcufft-11.3.0.4-dev/include\;/nix/store/vl1dficb0blxzqg6xqzfi5p119jvl2vi-libcusolver-11.7.1.2-dev/include\;/nix/store/n7x9kkzi2jdfj6f6yjwywfhyfmn957zp-cuda_cupti-12.6.80-dev/include\;/nix/store/sskxmb670akk0avrahrl4r6hp7925zh8-cuda_cudart-12.6.77-dev/include\;/nix/store/8a9vz66yzsar01lpgipmzq8skyk3ymkp-cuda_cccl-12.6.77-dev/include\;/nix/store/xd2xrldv3lbg1bk93nr0yccy6j0vhh2k-cudnn-9.11.0.98-dev/include\;/nix/store/0w4g3rxgkw9r0lv737rslqdk7wldmi0n-libcurand-10.3.7.77-dev/include\;/nix/store/m0s4p867fk6wk8ba7ym9yff4mayqjhlw-libcusparse-12.5.4.2-dev/include\;/nix/store/blh9iyvjkmwd871mfjvfhnp7njwgnc6b-cuda_profiler_api-12.6.77-dev/include\;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev/include\;/nix/store/4pwy3k2s52ppzbs3k6d58kda8jhmiim4-libcufile-1.11.1.6-dev/include 
-DCUDAToolkit_ROOT=/nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85\;/nix/store/1qgrl2sgdj5m7llm2vs9690gd9998psq-cudnn-9.11.0.98\;/nix/store/d2z15dzsgfm4r2yyl16n3wc0sw8z6fia-cuda_cupti-12.6.80-lib\;/nix/store/86ngm5djfbl6a0i43j282680chqz1vr8-libcusparse-12.5.4.2-lib\;/nix/store/bmph9rbyqnyjs02zriwq78kg16h12wi6-libcublas-12.6.4.1-lib\;/nix/store/wny8xmyma0ziffas96ansxgmjfqpw393-cuda_nvrtc-12.6.85-lib\;/nix/store/j40ndiqjiqbiqrbfmgmkzz6w8757cgvk-cuda_nvml_dev-12.6.77-lib\;/nix/store/3ii532blh586xxavim32i21kr84wlcdc-cuda_profiler_api-12.6.77\;/nix/store/j32l8jnzckhdy2lzxgyd59y7p39y6b1d-libcusolver-11.7.1.2-static\;/nix/store/5iv2zpbf4k00ch4c5zfi5b8dlj90y3d3-cuda_cccl-12.6.77\;/nix/store/a8yi28jqv5185bbv10jpjja3x98i86hm-cuda_cudart-12.6.77-stubs\;/nix/store/ya85qn68jv6mlq6gh6phh5hwk3dkynag-cuda_cudart-12.6.77-static\;/nix/store/m65ribrsnk3gbabcx9ah6phgiil19j01-libcufile-1.11.1.6\;/nix/store/5f6h6xs5c74iqcjda3y73i290mfwfs9x-cuda_nvml_dev-12.6.77-dev\;/nix/store/r26q9f2lhsvimxha44g1xcck3adrdqwg-cuda_nvrtc-12.6.85-dev\;/nix/store/nj1a061pvzpq9dr65yj3jpjqcx6pr4fq-cuda_nvtx-12.6.77-dev\;/nix/store/bcvj4g3f3n6cpb6czcb5k8zdmyd94fwi-cuda_nvtx-12.6.77-lib\;/nix/store/9ik1skjb698l6vkx4m4wvx2nrr4sx0na-libcufft-11.3.0.4-dev\;/nix/store/k5rbpivsz3ilsxg91pgigp6la8ln3cv9-cuda_cupti-12.6.80\;/nix/store/vl1dficb0blxzqg6xqzfi5p119jvl2vi-libcusolver-11.7.1.2-dev\;/nix/store/f87x0n0gi2d7rxh1ja92za2ixcw60q2p-cuda_nvtx-12.6.77\;/nix/store/n7x9kkzi2jdfj6f6yjwywfhyfmn957zp-cuda_cupti-12.6.80-dev\;/nix/store/m0fwdgh4nmrjd0q9v4m2ly63qbcq2hi2-cuda_cudart-12.6.77\;/nix/store/qfaxx4b8l1alrrl0gbyb23k3j850c0v5-libcurand-10.3.7.77-static\;/nix/store/w1npzy8mfl28w7cib5idkg6nvlbzhpzq-libcufile-1.11.1.6-lib\;/nix/store/8abbm2gd77dv0l3acw0s18wln36aa0l5-cuda_cudart-12.6.77-lib\;/nix/store/ykb9bv2lqkf1wzy73q96cb04pybx9xa2-cuda_nvcc-12.6.85-static\;/nix/store/nw9ws2qvhgdb33qgfx4iqj517814qq8y-libcufft-11.3.0.4\;/nix/store/sskxmb670akk0avrahrl4r6hp7925zh8-cuda_cudart-12.6.77-dev\;/nix/store/mfc3ah
6lwfd8dfbs77b0z9i75c471b0n-libcufft-11.3.0.4-static\;/nix/store/zk3cg1ws6cskrzyhdr5d68f8zrkfk77d-cuda_nvrtc-12.6.85-static\;/nix/store/pcrirrvn2ya5d3r1y18s2zj4pm2jladw-libcusolver-11.7.1.2\;/nix/store/qdn67x8jrwr418air16kwicya4d747pq-libcufft-11.3.0.4-lib\;/nix/store/dg8hyrzy7sh3wdhcr4ywsz05cvl6vfyc-libcusparse-12.5.4.2\;/nix/store/8a9vz66yzsar01lpgipmzq8skyk3ymkp-cuda_cccl-12.6.77-dev\;/nix/store/wmcrrdxd3db58nklyp7yf90kknfdx6b5-libcurand-10.3.7.77-lib\;/nix/store/xd2xrldv3lbg1bk93nr0yccy6j0vhh2k-cudnn-9.11.0.98-dev\;/nix/store/0w4g3rxgkw9r0lv737rslqdk7wldmi0n-libcurand-10.3.7.77-dev\;/nix/store/jr1397g6pshvil5n4lnvp7dm24dm71h8-libcublas-12.6.4.1-static\;/nix/store/wq0wv7df58h6bgggnz964sk8m1hbkxxp-cuda_cupti-12.6.80-sample\;/nix/store/m0s4p867fk6wk8ba7ym9yff4mayqjhlw-libcusparse-12.5.4.2-dev\;/nix/store/blh9iyvjkmwd871mfjvfhnp7njwgnc6b-cuda_profiler_api-12.6.77-dev\;/nix/store/ngwsphsxf906z7cgwg32d1w83p809ywl-cudnn-9.11.0.98-static\;/nix/store/07zlxn68jyf4s263xafnjid55grmi7a2-cuda_nvrtc-12.6.85\;/nix/store/zyh7hqq402zc7dhafhbh9vycyzcfq256-libcurand-10.3.7.77\;/nix/store/x7mww4k0zzzb7bnffv0b22jqbyf1mg3v-cuda_cupti-12.6.80-static\;/nix/store/xvlapjc6spss1kvbjlq97m6pk19hfrxz-cuda_nvml_dev-12.6.77\;/nix/store/7j4zf0r8flh7l4x5pm1mgqb2vcabmcdj-libcusolver-11.7.1.2-lib\;/nix/store/gs8gw8bgjccrjxlyzhxa7h85gkxgqwhn-libcufile-1.11.1.6-static\;/nix/store/p9dnsv7mv8mqm9aisrckq8lm3zs3l7dk-cudnn-9.11.0.98-lib\;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev\;/nix/store/dpska4iiya4xa5zzzmqzx3ljws73bnds-cuda_nvml_dev-12.6.77-static\;/nix/store/gzykkbwmch7pxgfzf86fg0b928lz6b36-libcusparse-12.5.4.2-static\;/nix/store/nqn7lvw8gbwbymdhz4nak9wf9b5bbah9-libcublas-12.6.4.1\;/nix/store/4pwy3k2s52ppzbs3k6d58kda8jhmiim4-libcufile-1.11.1.6-dev -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc 
-DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 -DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/libexec -DCMAKE_INSTALL_LIBDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW -DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/strip -DCMAKE_RANLIB=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/ranlib 
-DCMAKE_AR=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/ar -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext -DPython_EXECUTABLE:STRING=/nix/store/r3gwdvvsgl1csl12f4pkhz0jhsch7bdy-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/g++ -DNVCC_THREADS=5 -DCUDAToolkit_INCLUDE_DIR=/nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/include\;/nix/store/5f6h6xs5c74iqcjda3y73i290mfwfs9x-cuda_nvml_dev-12.6.77-dev/include\;/nix/store/r26q9f2lhsvimxha44g1xcck3adrdqwg-cuda_nvrtc-12.6.85-dev/include\;/nix/store/9ik1skjb698l6vkx4m4wvx2nrr4sx0na-libcufft-11.3.0.4-dev/include\;/nix/store/vl1dficb0blxzqg6xqzfi5p119jvl2vi-libcusolver-11.7.1.2-dev/include\;/nix/store/n7x9kkzi2jdfj6f6yjwywfhyfmn957zp-cuda_cupti-12.6.80-dev/include\;/nix/store/sskxmb670akk0avrahrl4r6hp7925zh8-cuda_cudart-12.6.77-dev/include\;/nix/store/8a9vz66yzsar01lpgipmzq8skyk3ymkp-cuda_cccl-12.6.77-dev/include\;/nix/store/xd2xrldv3lbg1bk93nr0yccy6j0vhh2k-cudnn-9.11.0.98-dev/include\;/nix/store/0w4g3rxgkw9r0lv737rslqdk7wldmi0n-libcurand-10.3.7.77-dev/include\;/nix/store/m0s4p867fk6wk8ba7ym9yff4mayqjhlw-libcusparse-12.5.4.2-dev/include\;/nix/store/blh9iyvjkmwd871mfjvfhnp7njwgnc6b-cuda_profiler_api-12.6.77-dev/include\;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev/include\;/nix/store/4pwy3k2s52ppzbs3k6d58kda8jhmiim4-libcufile-1.11.1.6-dev/include 
-DCUDAToolkit_ROOT=/nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85\;/nix/store/1qgrl2sgdj5m7llm2vs9690gd9998psq-cudnn-9.11.0.98\;/nix/store/d2z15dzsgfm4r2yyl16n3wc0sw8z6fia-cuda_cupti-12.6.80-lib\;/nix/store/86ngm5djfbl6a0i43j282680chqz1vr8-libcusparse-12.5.4.2-lib\;/nix/store/bmph9rbyqnyjs02zriwq78kg16h12wi6-libcublas-12.6.4.1-lib\;/nix/store/wny8xmyma0ziffas96ansxgmjfqpw393-cuda_nvrtc-12.6.85-lib\;/nix/store/j40ndiqjiqbiqrbfmgmkzz6w8757cgvk-cuda_nvml_dev-12.6.77-lib\;/nix/store/3ii532blh586xxavim32i21kr84wlcdc-cuda_profiler_api-12.6.77\;/nix/store/j32l8jnzckhdy2lzxgyd59y7p39y6b1d-libcusolver-11.7.1.2-static\;/nix/store/5iv2zpbf4k00ch4c5zfi5b8dlj90y3d3-cuda_cccl-12.6.77\;/nix/store/a8yi28jqv5185bbv10jpjja3x98i86hm-cuda_cudart-12.6.77-stubs\;/nix/store/ya85qn68jv6mlq6gh6phh5hwk3dkynag-cuda_cudart-12.6.77-static\;/nix/store/m65ribrsnk3gbabcx9ah6phgiil19j01-libcufile-1.11.1.6\;/nix/store/5f6h6xs5c74iqcjda3y73i290mfwfs9x-cuda_nvml_dev-12.6.77-dev\;/nix/store/r26q9f2lhsvimxha44g1xcck3adrdqwg-cuda_nvrtc-12.6.85-dev\;/nix/store/9ik1skjb698l6vkx4m4wvx2nrr4sx0na-libcufft-11.3.0.4-dev\;/nix/store/k5rbpivsz3ilsxg91pgigp6la8ln3cv9-cuda_cupti-12.6.80\;/nix/store/vl1dficb0blxzqg6xqzfi5p119jvl2vi-libcusolver-11.7.1.2-dev\;/nix/store/n7x9kkzi2jdfj6f6yjwywfhyfmn957zp-cuda_cupti-12.6.80-dev\;/nix/store/m0fwdgh4nmrjd0q9v4m2ly63qbcq2hi2-cuda_cudart-12.6.77\;/nix/store/qfaxx4b8l1alrrl0gbyb23k3j850c0v5-libcurand-10.3.7.77-static\;/nix/store/w1npzy8mfl28w7cib5idkg6nvlbzhpzq-libcufile-1.11.1.6-lib\;/nix/store/8abbm2gd77dv0l3acw0s18wln36aa0l5-cuda_cudart-12.6.77-lib\;/nix/store/ykb9bv2lqkf1wzy73q96cb04pybx9xa2-cuda_nvcc-12.6.85-static\;/nix/store/nw9ws2qvhgdb33qgfx4iqj517814qq8y-libcufft-11.3.0.4\;/nix/store/sskxmb670akk0avrahrl4r6hp7925zh8-cuda_cudart-12.6.77-dev\;/nix/store/mfc3ah6lwfd8dfbs77b0z9i75c471b0n-libcufft-11.3.0.4-static\;/nix/store/zk3cg1ws6cskrzyhdr5d68f8zrkfk77d-cuda_nvrtc-12.6.85-static\;/nix/store/pcrirrvn2ya5d3r1y18s2zj4pm2jladw-libcusolver-11.7.1.2\;/nix/st
ore/qdn67x8jrwr418air16kwicya4d747pq-libcufft-11.3.0.4-lib\;/nix/store/dg8hyrzy7sh3wdhcr4ywsz05cvl6vfyc-libcusparse-12.5.4.2\;/nix/store/8a9vz66yzsar01lpgipmzq8skyk3ymkp-cuda_cccl-12.6.77-dev\;/nix/store/wmcrrdxd3db58nklyp7yf90kknfdx6b5-libcurand-10.3.7.77-lib\;/nix/store/xd2xrldv3lbg1bk93nr0yccy6j0vhh2k-cudnn-9.11.0.98-dev\;/nix/store/0w4g3rxgkw9r0lv737rslqdk7wldmi0n-libcurand-10.3.7.77-dev\;/nix/store/jr1397g6pshvil5n4lnvp7dm24dm71h8-libcublas-12.6.4.1-static\;/nix/store/wq0wv7df58h6bgggnz964sk8m1hbkxxp-cuda_cupti-12.6.80-sample\;/nix/store/m0s4p867fk6wk8ba7ym9yff4mayqjhlw-libcusparse-12.5.4.2-dev\;/nix/store/blh9iyvjkmwd871mfjvfhnp7njwgnc6b-cuda_profiler_api-12.6.77-dev\;/nix/store/ngwsphsxf906z7cgwg32d1w83p809ywl-cudnn-9.11.0.98-static\;/nix/store/07zlxn68jyf4s263xafnjid55grmi7a2-cuda_nvrtc-12.6.85\;/nix/store/zyh7hqq402zc7dhafhbh9vycyzcfq256-libcurand-10.3.7.77\;/nix/store/x7mww4k0zzzb7bnffv0b22jqbyf1mg3v-cuda_cupti-12.6.80-static\;/nix/store/xvlapjc6spss1kvbjlq97m6pk19hfrxz-cuda_nvml_dev-12.6.77\;/nix/store/7j4zf0r8flh7l4x5pm1mgqb2vcabmcdj-libcusolver-11.7.1.2-lib\;/nix/store/gs8gw8bgjccrjxlyzhxa7h85gkxgqwhn-libcufile-1.11.1.6-static\;/nix/store/p9dnsv7mv8mqm9aisrckq8lm3zs3l7dk-cudnn-9.11.0.98-lib\;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev\;/nix/store/dpska4iiya4xa5zzzmqzx3ljws73bnds-cuda_nvml_dev-12.6.77-static\;/nix/store/gzykkbwmch7pxgfzf86fg0b928lz6b36-libcusparse-12.5.4.2-static\;/nix/store/nqn7lvw8gbwbymdhz4nak9wf9b5bbah9-libcublas-12.6.4.1\;/nix/store/4pwy3k2s52ppzbs3k6d58kda8jhmiim4-libcufile-1.11.1.6-dev -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 
-DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> fixing cmake files... +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/libexec -DCMAKE_INSTALL_LIBDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW -DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/strip -DCMAKE_RANLIB=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ranlib -DCMAKE_AR=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ar -DCMAKE_C_COMPILER=gcc 
-DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext -DPython_EXECUTABLE:STRING=/nix/store/qal2apcjwlw2p2kk05dwqdgzh8ml687l-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ -DNVCC_THREADS=5 -DCUDAToolkit_INCLUDE_DIR=/nix/store/2dc9bgppqvyd6bd5m4j9zphiyhhd39lv-libcurand-10.3.9.90-dev/include\;/nix/store/x6d389mfn7v413ia2had715g7rdgghgm-cuda_nvrtc-12.8.93-dev/include\;/nix/store/4sz65s9xk80q9jij0i4zbp9xd1pmr3ja-libcusparse-12.5.8.93-dev/include\;/nix/store/11bshw90q985bpd9ds649qmgg0x54q7x-cudnn-9.11.0.98-dev/include\;/nix/store/8dwjdyr7y3dkqlgswpn9swz884lx62gf-cuda_cccl-12.8.90-dev/include\;/nix/store/4cq7zkla3djm6g5gkpzzx4gfikda2k7z-cuda_profiler_api-12.8.90-dev/include\;/nix/store/90nghg4zsrw6gki8y8hw4id3p31bc8rk-libcusolver-11.7.3.90-dev/include\;/nix/store/vg32acb8vlqyhkhabbgvmralfw0kwhi3-cuda_cudart-12.8.90-dev/include\;/nix/store/vqg4r8izl1fy2smmw4dwv4x1adkj0rfb-libcufft-11.3.3.83-dev/include\;/nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/include\;/nix/store/5pvax5f2dg278j43b4llkdxim9y0bjaf-cuda_nvml_dev-12.8.90-dev/include\;/nix/store/mps4gsnyk6s676zadvcykjxn08yghk5a-libcufile-1.13.1.3-dev/include\;/nix/store/gz9xyhflw755r8fcxkc816fp54sj0hl4-cuda_cupti-12.8.90-dev/include\;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev/include 
-DCUDAToolkit_ROOT=/nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93\;/nix/store/w96jlfiy431jnsww1x3ak3chhssa3i2s-libcusparse-12.5.8.93\;/nix/store/6zj6v3b9v8xdjs94iq1228slqwr757ij-libcublas-12.8.4.1\;/nix/store/q85pndpvaqdznfijmkn0mlfp8y3v08dl-cuda_cccl-12.8.90\;/nix/store/2dc9bgppqvyd6bd5m4j9zphiyhhd39lv-libcurand-10.3.9.90-dev\;/nix/store/cwy7010iwla9b2v1fx82sp66v12r913x-libcublas-12.8.4.1-lib\;/nix/store/x6d389mfn7v413ia2had715g7rdgghgm-cuda_nvrtc-12.8.93-dev\;/nix/store/22n25ss46s0hgspdp26qk025w9m393cd-libcublas-12.8.4.1-static\;/nix/store/sc5wnfvmk0j73xdppxj25kgk8s98lscs-cuda_nvrtc-12.8.93-lib\;/nix/store/54wqrrh6qbrwmv2wkz6b216ljrqbhcji-cudnn-9.11.0.98\;/nix/store/4sz65s9xk80q9jij0i4zbp9xd1pmr3ja-libcusparse-12.5.8.93-dev\;/nix/store/11bshw90q985bpd9ds649qmgg0x54q7x-cudnn-9.11.0.98-dev\;/nix/store/8dwjdyr7y3dkqlgswpn9swz884lx62gf-cuda_cccl-12.8.90-dev\;/nix/store/1v8m3gdw08hnbs7qa4jbkflm9lg1r5q6-libcurand-10.3.9.90\;/nix/store/jc58pv1cxhvpblrnzgaai60x04q6m0bp-cuda_nvml_dev-12.8.90-lib\;/nix/store/khwhv5d4kmzjpsm785iz3sva6i9sj9r5-libcufile-1.13.1.3-static\;/nix/store/xv6c2jcc3adyqks2xl28p4r0q1g4bc92-cuda_cupti-12.8.90\;/nix/store/a2h2yfjfx0si8smnqmghw7ccj0qbnv81-cuda_cupti-12.8.90-lib\;/nix/store/4cq7zkla3djm6g5gkpzzx4gfikda2k7z-cuda_profiler_api-12.8.90-dev\;/nix/store/xccbzbpcn8r506zdvhvbkqkilhlrh3c5-cuda_cudart-12.8.90-lib\;/nix/store/acbir62i1d7kvka4plmxsq8442z7r1l2-cuda_cudart-12.8.90-stubs\;/nix/store/ckkcbggf4x93zg3xn9xr00jgxs2x5p21-cuda_nvml_dev-12.8.90-static\;/nix/store/ml3bkm8bz1lnjmfd8lyxbjqpi1llasr2-libcusolver-11.7.3.90\;/nix/store/9zlrjnq7lisarny3llszk131vy816x2w-libcufile-1.13.1.3\;/nix/store/90nghg4zsrw6gki8y8hw4id3p31bc8rk-libcusolver-11.7.3.90-dev\;/nix/store/vg32acb8vlqyhkhabbgvmralfw0kwhi3-cuda_cudart-12.8.90-dev\;/nix/store/y27d2s3rcw8d17wcw23glhlj5rhs8d6y-cuda_cudart-12.8.90\;/nix/store/n96pib9yj31n031dmrrx43m61js1r5rn-cuda_nvcc-12.8.93-static\;/nix/store/pabakly3280dnghh3i89wklfm61raf7z-cuda_cupti-12.8.90-sample\;/nix/store/
l0jiwp1f0dhigd41qqf408c5qyabz2vd-cudnn-9.11.0.98-static\;/nix/store/95lzbxp68m127n6hyllbr3dh2mlj7y8m-libcufft-11.3.3.83\;/nix/store/lxsd5l6hnqcfgqc1nsn8mmmpx385m3k8-libcusparse-12.5.8.93-lib\;/nix/store/vqg4r8izl1fy2smmw4dwv4x1adkj0rfb-libcufft-11.3.3.83-dev\;/nix/store/4b9rdinnksj1856siw3qmwi9f10480ii-cuda_nvrtc-12.8.93-static\;/nix/store/qh7zggir1ikzh3kvkhi2mqzpyisl4153-libcurand-10.3.9.90-static\;/nix/store/n25l4gcpw8cry4rg2a4c9jw3f53i65zd-libcusolver-11.7.3.90-lib\;/nix/store/xh73kc8spwfvd6w6wc63pyq3zm6qlrja-cuda_nvml_dev-12.8.90\;/nix/store/bgiqy1z8588hgcdzyh9brhc015w3nii0-libcurand-10.3.9.90-lib\;/nix/store/5pvax5f2dg278j43b4llkdxim9y0bjaf-cuda_nvml_dev-12.8.90-dev\;/nix/store/7lf23alvk7yh64flf2mj6smx66sqyz9d-libcufile-1.13.1.3-lib\;/nix/store/lfqj2ni7r0ir3n840b8r1lh63mnqr0ar-libcusparse-12.5.8.93-static\;/nix/store/qmw5pq21avnfvsk657k0zr4nsgwxa4jm-cuda_cudart-12.8.90-static\;/nix/store/826d39r2b4gwafqsyhvzq2bmqv8ygzrd-cuda_profiler_api-12.8.90\;/nix/store/g52lygjflrsyr6wahpf0rvs3fpna3wq9-cudnn-9.11.0.98-lib\;/nix/store/gxw5c9f7q2f1pmy0g1zyblb8p2p891a4-libcufft-11.3.3.83-lib\;/nix/store/pbsi8w1in7q44z83ndqsaxyzfrr2frgh-cuda_nvrtc-12.8.93\;/nix/store/mps4gsnyk6s676zadvcykjxn08yghk5a-libcufile-1.13.1.3-dev\;/nix/store/mvfnbb1m20fkv2n0j69ky9s9afn8p7h1-libcufft-11.3.3.83-static\;/nix/store/8byjxgnvhcyav2283wcxp752d8280c36-libcusolver-11.7.3.90-static\;/nix/store/gz9xyhflw755r8fcxkc816fp54sj0hl4-cuda_cupti-12.8.90-dev\;/nix/store/jyd8jp3q1d408n8842rb8g6ziviwm7q1-cuda_cupti-12.8.90-static\;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 
-DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/libexec -DCMAKE_INSTALL_LIBDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW -DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/strip -DCMAKE_RANLIB=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ranlib -DCMAKE_AR=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ar -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext 
-DPython_EXECUTABLE:STRING=/nix/store/aikr517kmcd8r2nrrj70jq71d7352qiq-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ -DNVCC_THREADS=5 -DCUDAToolkit_INCLUDE_DIR=/nix/store/kky5wd8qwb0hx3jb3j9qc1bkwznw3z83-libcusparse-12.5.10.65-dev/include\;/nix/store/dd8wl3nnsigw2gj5bwaiswla97jpw1jz-libcublas-12.9.1.4-dev/include\;/nix/store/zsmc0yjbjrfbamm9ycrlz5yzi5hrbag1-libcurand-10.3.10.19-dev/include\;/nix/store/ip4lb9ximc445dbdkdvia4whx83g00g3-libcusolver-11.7.5.82-dev/include\;/nix/store/81xppf0rrqfasvg7wy4z891ab473nb9v-libcufile-1.14.1.1-dev/include\;/nix/store/nkvyh0qxbfj2wbm3r800xd6x1fhs1s4x-cuda_cccl-12.9.27-dev/include\;/nix/store/ik96pdimvw3bjj8wdr6laxycnn5lpwby-libcufft-11.4.1.4-dev/include\;/nix/store/f9r19xpj8qayy3b74gx3gbjrq0z1aq3b-cuda_nvml_dev-12.9.79-dev/include\;/nix/store/0kycn0pb0x46h16afxw2bjrm1gjq1355-cuda_profiler_api-12.9.79-dev/include\;/nix/store/z2xfln4d3r92hjjihlq5w6hvh5qhpcb4-cudnn-9.11.0.98-dev/include\;/nix/store/x4w41r4jyapqwdghvi6xrpd0mnim4x08-cuda_cudart-12.9.79-dev/include\;/nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86/include\;/nix/store/f21f8hghg4fiwa2ix29h1zy854p7q4v6-cuda_nvrtc-12.9.86-dev/include\;/nix/store/ns0brisbkgrjyfi16rlyjjgcym4jk6qv-cuda_cupti-12.9.79-dev/include 
-DCUDAToolkit_ROOT=/nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86\;/nix/store/q2al0drhrl0yxk97xbsjl8d0h25kmsq9-libcurand-10.3.10.19-lib\;/nix/store/ax1ssn45048qbmyy19basgv6q64y5jy0-cuda_cupti-12.9.79\;/nix/store/m09542l6q83flp3asv2r4j3wcbjqksvg-libcufile-1.14.1.1-static\;/nix/store/b3wbcra9cziq8bwf3yhmj2nn1mf5bqy2-cuda_cudart-12.9.79-lib\;/nix/store/j5kp5fg9mn6hhslk18wbmskc7v96l353-cuda_cupti-12.9.79-static\;/nix/store/kky5wd8qwb0hx3jb3j9qc1bkwznw3z83-libcusparse-12.5.10.65-dev\;/nix/store/dd8wl3nnsigw2gj5bwaiswla97jpw1jz-libcublas-12.9.1.4-dev\;/nix/store/zsmc0yjbjrfbamm9ycrlz5yzi5hrbag1-libcurand-10.3.10.19-dev\;/nix/store/3s79bz4ldkhlks6jf9a2jd4r34y6018b-libcurand-10.3.10.19\;/nix/store/v48xzq66pzmygxqkws17n9nvpa7lad9d-cuda_nvml_dev-12.9.79\;/nix/store/6via2axi1n31n685jii6dwaiqca8b2rc-cuda_nvcc-12.9.86-static\;/nix/store/v0hx9fqdlmz9kvjd9sqr2zc141ny10yn-cuda_profiler_api-12.9.79\;/nix/store/ip4lb9ximc445dbdkdvia4whx83g00g3-libcusolver-11.7.5.82-dev\;/nix/store/8cig7k11qv5g8x0j8n2mbdfzwrnf7cg2-cuda_cudart-12.9.79-stubs\;/nix/store/xg8pj5m74n2h3v8kgxbvmbpcl90rzmlx-cudnn-9.11.0.98-static\;/nix/store/v4b7mkhyq1akczzkcyynj7y9c61l9dc7-cuda_cudart-12.9.79-static\;/nix/store/hw2swakbrvi4innrymcw8i2m98p73br0-cuda_cupti-12.9.79-sample\;/nix/store/s1i2kadnni2m4skpzzqzfzc3bpmrxi7p-libcusparse-12.5.10.65-lib\;/nix/store/81xppf0rrqfasvg7wy4z891ab473nb9v-libcufile-1.14.1.1-dev\;/nix/store/0a83zdhkh2i9d97r4zqdn8fi8vn4wfk3-libcublas-12.9.1.4-static\;/nix/store/nkvyh0qxbfj2wbm3r800xd6x1fhs1s4x-cuda_cccl-12.9.27-dev\;/nix/store/jnhjz87sm9nbnb72n54jj2l99szrzpg2-libcusparse-12.5.10.65\;/nix/store/ik96pdimvw3bjj8wdr6laxycnn5lpwby-libcufft-11.4.1.4-dev\;/nix/store/d1m6c5i6y6ncjygpdmv1b4pmd91hvjr2-cuda_cupti-12.9.79-lib\;/nix/store/49p6af3v11dcxvq9andr6l8csa2sr4j4-cuda_nvrtc-12.9.86-static\;/nix/store/bfygrgghga26l7br5d5j3h6hd1s21rkn-cudnn-9.11.0.98\;/nix/store/a6an9chi5dvjsybrfrxql0bn76xswzpa-libcufft-11.4.1.4\;/nix/store/f9r19xpj8qayy3b74gx3gbjrq0z1aq3b-cuda_nvml_dev-12
.9.79-dev\;/nix/store/7zy91byrxpnyzhjlwham2gqyir2x6f54-libcusolver-11.7.5.82-lib\;/nix/store/0kycn0pb0x46h16afxw2bjrm1gjq1355-cuda_profiler_api-12.9.79-dev\;/nix/store/cx0hyla7fkqqc5hh1gn4hkarjyjvbjhf-libcusparse-12.5.10.65-static\;/nix/store/3yi8kx62nklnyn77zn4z23hi03l9c7ff-libcusolver-11.7.5.82-static\;/nix/store/z2xfln4d3r92hjjihlq5w6hvh5qhpcb4-cudnn-9.11.0.98-dev\;/nix/store/86nq76ks8vlgjdsnh1hkskyfw7mm3plc-cuda_cccl-12.9.27\;/nix/store/01ywykdxfkvp64318anifgx7zaavz9ql-cuda_nvml_dev-12.9.79-lib\;/nix/store/qv2m9i0nby2p03xx37mkkm84dlqb9s84-cuda_cudart-12.9.79\;/nix/store/a09saq5rl5jxbgv9gqllx0080ypjk00x-libcufile-1.14.1.1-lib\;/nix/store/0l18n4dhavr0p4rk0nyqqjr8paacak13-libcufile-1.14.1.1\;/nix/store/r8ly0w88qv4gw3lhd784ha0ag221c23s-cuda_nvrtc-12.9.86-lib\;/nix/store/rngn6cls1blhilrw78xb3pjgwghibhzk-libcurand-10.3.10.19-static\;/nix/store/x4w41r4jyapqwdghvi6xrpd0mnim4x08-cuda_cudart-12.9.79-dev\;/nix/store/ikw7sqic4kknjkp50dr54khgs06q1hbv-cuda_nvml_dev-12.9.79-static\;/nix/store/bzdnjn29xj8a73wg16qrz0sswi9svp0x-libcublas-12.9.1.4\;/nix/store/62hqkwasnanq5i1j63z4clc0s4c61k1r-libcufft-11.4.1.4-static\;/nix/store/5sjldyn2vmm4ky24v1f9ggs0hps496q3-libcusolver-11.7.5.82\;/nix/store/9c924z3749bfm078bwq4ad12kjz46pjf-libcufft-11.4.1.4-lib\;/nix/store/f21f8hghg4fiwa2ix29h1zy854p7q4v6-cuda_nvrtc-12.9.86-dev\;/nix/store/c1kdvq8xqqkwzzazl99w20h4x9z0f9pc-libcublas-12.9.1.4-lib\;/nix/store/ns0brisbkgrjyfi16rlyjjgcym4jk6qv-cuda_cupti-12.9.79-dev\;/nix/store/h6kzw3gvlv4sa0apb4fflpjlirhj72ga-cudnn-9.11.0.98-lib\;/nix/store/f5gvpjis5y727lw6vzr2h1zkb3hm08k2-cuda_nvrtc-12.9.86 -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 
-DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/libexec -DCMAKE_INSTALL_LIBDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW -DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/hdw3ksc9knwilc0sc7bnzhilimcbsddm-gcc-wrapper-11.5.0/bin/strip -DCMAKE_RANLIB=/nix/store/hdw3ksc9knwilc0sc7bnzhilimcbsddm-gcc-wrapper-11.5.0/bin/ranlib -DCMAKE_AR=/nix/store/hdw3ksc9knwilc0sc7bnzhilimcbsddm-gcc-wrapper-11.5.0/bin/ar -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext 
-DPython_EXECUTABLE:STRING=/nix/store/cmr2xc65fzb6kv5ifyhiyvzs5lay03z5-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/hdw3ksc9knwilc0sc7bnzhilimcbsddm-gcc-wrapper-11.5.0/bin/g++ -DNVCC_THREADS=5 -DCUDAToolkit_INCLUDE_DIR=/nix/store/gjp3yjc2c9n1iphj27ndbxz4n27c2p1p-libcusolver-11.4.1.48-dev/include\;/nix/store/vrfn7hnjmyxmq636wwnd0mwxs1iz8bvr-cudnn-9.8.0.87-dev/include\;/nix/store/6yx59h9krl2wyl3l9xs0pg9kxxixdcwm-libcurand-10.3.0.86-dev/include\;/nix/store/60il6j9pd01rj162b3ajz86himmzbl2n-cuda_cccl-11.8.89-dev/include\;/nix/store/bw6ppn3fjw8vvyhnch11334qa7ha0j82-libcufft-10.9.0.58-dev/include\;/nix/store/g41qrp3nxgsf0j3y925aqb9yhvx57bgs-libcufile-1.4.0.31-dev/include\;/nix/store/as7abq4l2vpmgqrkxhs02ypd7swd95rv-libcublas-11.11.3.6-dev/include\;/nix/store/k6acl4f3bapawgxsk0s8zbh8rwz4ff9z-cuda_nvml_dev-11.8.86-dev/include\;/nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89/include\;/nix/store/4w3l210dy06b72mjg6ahcwla8dvp3s29-cuda_nvrtc-11.8.89-dev/include\;/nix/store/a0jvdsx38l64ibkcj6zklinz22xzm84i-cuda_cupti-11.8.87-dev/include\;/nix/store/zp3jfl220q91778z0lbbi0kwhdsx43mh-cuda_nvtx-11.8.86-dev/include\;/nix/store/i3mdycl6w6zxyn7qhng2bqrjbn4zvqfq-cuda_profiler_api-11.8.86-dev/include\;/nix/store/fd75174gl2zdqfm4bl7dnm0qbydrmzi5-cuda_cudart-11.8.89-dev/include\;/nix/store/2iqd17dwfr5vp3z5a3bhab2gqq4lhhd2-libcusparse-11.7.5.86-dev/include 
-DCUDAToolkit_ROOT=/nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89\;/nix/store/0ac7a8nvkwjhibz80p4zkls1bzqb0y2g-libcufft-10.9.0.58-lib\;/nix/store/81diiz22kjbj7q74w12g6mqlj5y54w5c-cuda_nvml_dev-11.8.86\;/nix/store/pxd9h8z8qz9g0x98dafc3g4f2r6rnwg8-libcufile-1.4.0.31\;/nix/store/w5d6bfkfs3cl7w2fi953sphjipgn96j2-cuda_nvrtc-11.8.89-lib\;/nix/store/gjp3yjc2c9n1iphj27ndbxz4n27c2p1p-libcusolver-11.4.1.48-dev\;/nix/store/sh18pzickwhknyhy5dqrqnapwfbh1q34-libcufile-1.4.0.31-lib\;/nix/store/d0h94ramphqbn3hz8hwwws1cngh6z2m4-libcusolver-11.4.1.48-lib\;/nix/store/nmlrk730kllm7x24w861rbhyklv228x5-libcusparse-11.7.5.86\;/nix/store/j3gfm5pqpvyq6wk8j1iswqy56vbhybwj-libcusparse-11.7.5.86-static\;/nix/store/vrfn7hnjmyxmq636wwnd0mwxs1iz8bvr-cudnn-9.8.0.87-dev\;/nix/store/6yx59h9krl2wyl3l9xs0pg9kxxixdcwm-libcurand-10.3.0.86-dev\;/nix/store/vfzb4kby50x5cv65rhgayx851xlh8vqs-cuda_nvcc-11.8.89-static\;/nix/store/zbq3x8ms1rabyv8sq6n1918xvyfzl8pq-cuda_cudart-11.8.89\;/nix/store/yh4lr1pa88hpi1ljfrkm6ygccmsjjraj-cuda_cudart-11.8.89-static\;/nix/store/xsmpfb4xw185q61nm14dvrkbp08f64jr-libcurand-10.3.0.86-static\;/nix/store/p6ywk8npbrsj7yfh06kn9hgp81wjan9h-libcusolver-11.4.1.48-static\;/nix/store/z0913w52zx2fzxw5gj4izkh5msghr31z-libcufft-10.9.0.58-static\;/nix/store/i16adfknmb7d99m80hp1l2c6r69y15j6-cuda_nvrtc-11.8.89\;/nix/store/y9wwdk6lfdwafsjdbf08k2d3kriznvl5-libcufile-1.4.0.31-static\;/nix/store/jqjwqh18vvarbifcgaimgjqpw7b0c0xa-cuda_profiler_api-11.8.86\;/nix/store/60il6j9pd01rj162b3ajz86himmzbl2n-cuda_cccl-11.8.89-dev\;/nix/store/bw6ppn3fjw8vvyhnch11334qa7ha0j82-libcufft-10.9.0.58-dev\;/nix/store/g41qrp3nxgsf0j3y925aqb9yhvx57bgs-libcufile-1.4.0.31-dev\;/nix/store/wvk7nw0xsvrr7i2n2s2lbb1cxlx031xl-cudnn-9.8.0.87-lib\;/nix/store/6jcsnhxd7n5c4g29cq5ysyswvckmf780-cuda_nvrtc-11.8.89-static\;/nix/store/as7abq4l2vpmgqrkxhs02ypd7swd95rv-libcublas-11.11.3.6-dev\;/nix/store/jaz051ajib9yqc2ql0w8ysxin6gvry63-cuda_cudart-11.8.89-lib\;/nix/store/92if92f0ys364hnlx382p2a5229ibwg4-cuda_cupti-11.8.8
7\;/nix/store/1l08kr5ip1jyfkcf9fpn9mxhnhssvhdz-cuda_nvtx-11.8.86-lib\;/nix/store/cf0srs9y3zh5hpgz0s9fbb83s1mvb9pv-libcublas-11.11.3.6-static\;/nix/store/mq4jx3r0iwwlp6i0knzsn9s4kriil1gq-cudnn-9.8.0.87-static\;/nix/store/m9w4smq9ypawc8m1rdic65vpjahyzsd4-libcublas-11.11.3.6\;/nix/store/qc9x1wij4anq3rsvzf83ssfzrkrrrp9v-libcusolver-11.4.1.48\;/nix/store/k6acl4f3bapawgxsk0s8zbh8rwz4ff9z-cuda_nvml_dev-11.8.86-dev\;/nix/store/v6nam0bvwyfnl5q7svm4k5b6b1s4jrka-cuda_nvml_dev-11.8.86-lib\;/nix/store/4w3l210dy06b72mjg6ahcwla8dvp3s29-cuda_nvrtc-11.8.89-dev\;/nix/store/izz465hqqa4hvk0073yqja7w7nhyhhqi-cudnn-9.8.0.87\;/nix/store/2cka8a1cl63ilihil5p05hbp4iqm05q6-cuda_cupti-11.8.87-sample\;/nix/store/a0jvdsx38l64ibkcj6zklinz22xzm84i-cuda_cupti-11.8.87-dev\;/nix/store/s5gcqs33h7y3xcq86bfz57rkiapnkkqv-libcufft-10.9.0.58\;/nix/store/7j5vljppl7mnyhysfcq6sabpiv0i84sb-libcusparse-11.7.5.86-lib\;/nix/store/62kg9m9aih57vci8r9qziyxpra7p5zln-libcurand-10.3.0.86-lib\;/nix/store/lw74qqzaksb8yy2mb576w5wdv1blbkyj-libcurand-10.3.0.86\;/nix/store/s3xyp6khw00lrrb0snlyrhvpq6mivc1r-cuda_cupti-11.8.87-lib\;/nix/store/qahdyi94lsj3v9k91gc02iv579gxp9a9-libcublas-11.11.3.6-lib\;/nix/store/si6xyjc32vrnrbrxb63s1bvhwvc087ab-cuda_cccl-11.8.89\;/nix/store/4i63ypmb0k44d6vvzg1gwficzmx0z5mq-cuda_nvtx-11.8.86\;/nix/store/gxyyzhr7p1wb2h1fmxr5jxspdl97139z-libcufile-1.4.0.31-sample\;/nix/store/53ycm00ckf5bgjlywf6f7x6d021cxj8r-cuda_cudart-11.8.89-stubs\;/nix/store/zp3jfl220q91778z0lbbi0kwhdsx43mh-cuda_nvtx-11.8.86-dev\;/nix/store/i3mdycl6w6zxyn7qhng2bqrjbn4zvqfq-cuda_profiler_api-11.8.86-dev\;/nix/store/7xakk7z1700acb1bx5lrkm1wlhrryng5-cuda_cupti-11.8.87-static\;/nix/store/fd75174gl2zdqfm4bl7dnm0qbydrmzi5-cuda_cudart-11.8.89-dev\;/nix/store/2iqd17dwfr5vp3z5a3bhab2gqq4lhhd2-libcusparse-11.7.5.86-dev -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc 
-DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 -DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> Running phase: unpackPhase +torch_harmonics_attn-torch-ext> unpacking source archive /nix/store/nzsl2gnl959ddqrnkyrvsybi390dik87-source +torch_harmonics_attn-torch-ext> source root is source +torch_harmonics_attn-torch-ext> Running phase: patchPhase +torch_harmonics_attn-torch-ext> Running phase: updateAutotoolsGnuConfigScriptsPhase +torch_harmonics_attn-torch-ext> Running phase: configurePhase +torch_harmonics_attn-torch-ext> Executing setupCUDAToolkitCompilers +torch_harmonics_attn-torch-ext> fixing cmake files... +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 13.4.0 +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 13.4.0 +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 14.3.0 +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 14.3.0 +torch_harmonics_attn-torch-ext> cmake flags: -GNinja -DCMAKE_FIND_USE_SYSTEM_PACKAGE_REGISTRY=OFF -DCMAKE_FIND_USE_PACKAGE_REGISTRY=OFF -DCMAKE_EXPORT_NO_PACKAGE_REGISTRY=ON -DCMAKE_BUILD_TYPE=Release -DBUILD_TESTING=OFF -DCMAKE_INSTALL_LOCALEDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/share/locale -DCMAKE_INSTALL_LIBEXECDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/libexec 
-DCMAKE_INSTALL_LIBDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/lib -DCMAKE_INSTALL_DOCDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/share/doc/torch_harmonics_attn -DCMAKE_INSTALL_INFODIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/share/info -DCMAKE_INSTALL_MANDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/share/man -DCMAKE_INSTALL_INCLUDEDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/include -DCMAKE_INSTALL_SBINDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/sbin -DCMAKE_INSTALL_BINDIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/bin -DCMAKE_INSTALL_NAME_DIR=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/lib -DCMAKE_POLICY_DEFAULT_CMP0025=NEW -DCMAKE_FIND_FRAMEWORK=LAST -DCMAKE_STRIP=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/strip -DCMAKE_RANLIB=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ranlib -DCMAKE_AR=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/ar -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_INSTALL_PREFIX=/nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext -DPython_EXECUTABLE:STRING=/nix/store/wirj6dihrpcch7idfd7jy4l0hqfsgkk1-python3-3.13.6-env/bin/python -DCMAKE_CUDA_HOST_COMPILER:STRING=/nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ -DNVCC_THREADS=5 
-DCUDAToolkit_INCLUDE_DIR=/nix/store/2dc9bgppqvyd6bd5m4j9zphiyhhd39lv-libcurand-10.3.9.90-dev/include\;/nix/store/x6d389mfn7v413ia2had715g7rdgghgm-cuda_nvrtc-12.8.93-dev/include\;/nix/store/4sz65s9xk80q9jij0i4zbp9xd1pmr3ja-libcusparse-12.5.8.93-dev/include\;/nix/store/11bshw90q985bpd9ds649qmgg0x54q7x-cudnn-9.11.0.98-dev/include\;/nix/store/8dwjdyr7y3dkqlgswpn9swz884lx62gf-cuda_cccl-12.8.90-dev/include\;/nix/store/4cq7zkla3djm6g5gkpzzx4gfikda2k7z-cuda_profiler_api-12.8.90-dev/include\;/nix/store/90nghg4zsrw6gki8y8hw4id3p31bc8rk-libcusolver-11.7.3.90-dev/include\;/nix/store/vg32acb8vlqyhkhabbgvmralfw0kwhi3-cuda_cudart-12.8.90-dev/include\;/nix/store/vqg4r8izl1fy2smmw4dwv4x1adkj0rfb-libcufft-11.3.3.83-dev/include\;/nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/include\;/nix/store/5pvax5f2dg278j43b4llkdxim9y0bjaf-cuda_nvml_dev-12.8.90-dev/include\;/nix/store/klis291y7cza60yzgkxzbid80bnyshmr-cuda_nvtx-12.8.90-dev/include\;/nix/store/mps4gsnyk6s676zadvcykjxn08yghk5a-libcufile-1.13.1.3-dev/include\;/nix/store/gz9xyhflw755r8fcxkc816fp54sj0hl4-cuda_cupti-12.8.90-dev/include\;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev/include 
-DCUDAToolkit_ROOT=/nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93\;/nix/store/w96jlfiy431jnsww1x3ak3chhssa3i2s-libcusparse-12.5.8.93\;/nix/store/6zj6v3b9v8xdjs94iq1228slqwr757ij-libcublas-12.8.4.1\;/nix/store/q85pndpvaqdznfijmkn0mlfp8y3v08dl-cuda_cccl-12.8.90\;/nix/store/2dc9bgppqvyd6bd5m4j9zphiyhhd39lv-libcurand-10.3.9.90-dev\;/nix/store/cwy7010iwla9b2v1fx82sp66v12r913x-libcublas-12.8.4.1-lib\;/nix/store/x6d389mfn7v413ia2had715g7rdgghgm-cuda_nvrtc-12.8.93-dev\;/nix/store/22n25ss46s0hgspdp26qk025w9m393cd-libcublas-12.8.4.1-static\;/nix/store/sc5wnfvmk0j73xdppxj25kgk8s98lscs-cuda_nvrtc-12.8.93-lib\;/nix/store/54wqrrh6qbrwmv2wkz6b216ljrqbhcji-cudnn-9.11.0.98\;/nix/store/4sz65s9xk80q9jij0i4zbp9xd1pmr3ja-libcusparse-12.5.8.93-dev\;/nix/store/11bshw90q985bpd9ds649qmgg0x54q7x-cudnn-9.11.0.98-dev\;/nix/store/8dwjdyr7y3dkqlgswpn9swz884lx62gf-cuda_cccl-12.8.90-dev\;/nix/store/1v8m3gdw08hnbs7qa4jbkflm9lg1r5q6-libcurand-10.3.9.90\;/nix/store/jc58pv1cxhvpblrnzgaai60x04q6m0bp-cuda_nvml_dev-12.8.90-lib\;/nix/store/khwhv5d4kmzjpsm785iz3sva6i9sj9r5-libcufile-1.13.1.3-static\;/nix/store/xv6c2jcc3adyqks2xl28p4r0q1g4bc92-cuda_cupti-12.8.90\;/nix/store/a2h2yfjfx0si8smnqmghw7ccj0qbnv81-cuda_cupti-12.8.90-lib\;/nix/store/4cq7zkla3djm6g5gkpzzx4gfikda2k7z-cuda_profiler_api-12.8.90-dev\;/nix/store/5f6dvklv5d0mvygrrf0vzp0smcn7kk01-cuda_nvtx-12.8.90\;/nix/store/xccbzbpcn8r506zdvhvbkqkilhlrh3c5-cuda_cudart-12.8.90-lib\;/nix/store/acbir62i1d7kvka4plmxsq8442z7r1l2-cuda_cudart-12.8.90-stubs\;/nix/store/ckkcbggf4x93zg3xn9xr00jgxs2x5p21-cuda_nvml_dev-12.8.90-static\;/nix/store/ml3bkm8bz1lnjmfd8lyxbjqpi1llasr2-libcusolver-11.7.3.90\;/nix/store/9zlrjnq7lisarny3llszk131vy816x2w-libcufile-1.13.1.3\;/nix/store/90nghg4zsrw6gki8y8hw4id3p31bc8rk-libcusolver-11.7.3.90-dev\;/nix/store/vg32acb8vlqyhkhabbgvmralfw0kwhi3-cuda_cudart-12.8.90-dev\;/nix/store/y27d2s3rcw8d17wcw23glhlj5rhs8d6y-cuda_cudart-12.8.90\;/nix/store/wa9pr3485k3mw8jhv7i9kfzjrqmdl5bb-cuda_nvtx-12.8.90-lib\;/nix/store/n96pib9yj31
n031dmrrx43m61js1r5rn-cuda_nvcc-12.8.93-static\;/nix/store/pabakly3280dnghh3i89wklfm61raf7z-cuda_cupti-12.8.90-sample\;/nix/store/l0jiwp1f0dhigd41qqf408c5qyabz2vd-cudnn-9.11.0.98-static\;/nix/store/95lzbxp68m127n6hyllbr3dh2mlj7y8m-libcufft-11.3.3.83\;/nix/store/lxsd5l6hnqcfgqc1nsn8mmmpx385m3k8-libcusparse-12.5.8.93-lib\;/nix/store/vqg4r8izl1fy2smmw4dwv4x1adkj0rfb-libcufft-11.3.3.83-dev\;/nix/store/4b9rdinnksj1856siw3qmwi9f10480ii-cuda_nvrtc-12.8.93-static\;/nix/store/qh7zggir1ikzh3kvkhi2mqzpyisl4153-libcurand-10.3.9.90-static\;/nix/store/n25l4gcpw8cry4rg2a4c9jw3f53i65zd-libcusolver-11.7.3.90-lib\;/nix/store/xh73kc8spwfvd6w6wc63pyq3zm6qlrja-cuda_nvml_dev-12.8.90\;/nix/store/bgiqy1z8588hgcdzyh9brhc015w3nii0-libcurand-10.3.9.90-lib\;/nix/store/5pvax5f2dg278j43b4llkdxim9y0bjaf-cuda_nvml_dev-12.8.90-dev\;/nix/store/7lf23alvk7yh64flf2mj6smx66sqyz9d-libcufile-1.13.1.3-lib\;/nix/store/klis291y7cza60yzgkxzbid80bnyshmr-cuda_nvtx-12.8.90-dev\;/nix/store/lfqj2ni7r0ir3n840b8r1lh63mnqr0ar-libcusparse-12.5.8.93-static\;/nix/store/qmw5pq21avnfvsk657k0zr4nsgwxa4jm-cuda_cudart-12.8.90-static\;/nix/store/826d39r2b4gwafqsyhvzq2bmqv8ygzrd-cuda_profiler_api-12.8.90\;/nix/store/g52lygjflrsyr6wahpf0rvs3fpna3wq9-cudnn-9.11.0.98-lib\;/nix/store/gxw5c9f7q2f1pmy0g1zyblb8p2p891a4-libcufft-11.3.3.83-lib\;/nix/store/pbsi8w1in7q44z83ndqsaxyzfrr2frgh-cuda_nvrtc-12.8.93\;/nix/store/mps4gsnyk6s676zadvcykjxn08yghk5a-libcufile-1.13.1.3-dev\;/nix/store/mvfnbb1m20fkv2n0j69ky9s9afn8p7h1-libcufft-11.3.3.83-static\;/nix/store/8byjxgnvhcyav2283wcxp752d8280c36-libcusolver-11.7.3.90-static\;/nix/store/gz9xyhflw755r8fcxkc816fp54sj0hl4-cuda_cupti-12.8.90-dev\;/nix/store/jyd8jp3q1d408n8842rb8g6ziviwm7q1-cuda_cupti-12.8.90-static\;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev -DPROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DProtobuf_PROTOC_EXE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc 
-DProtobuf_PROTOC_EXECUTABLE=/nix/store/g82m0ia59azh4a1bcrk0r15qck6hp8da-protobuf-31.1/bin/protoc -DPYBIND11_PYTHONLIBS_OVERWRITE=OFF -DPYTHON_EXECUTABLE=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/bin/python3.13 -DPYTHON_INCLUDE_DIR=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/include/python3.13 -DPYTHON_SITE_PACKAGES=/nix/store/iyff8129iampdw13nlfqalzhxy8y1hi9-python3-3.13.6/lib/python3.13/site-packages +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 11.5.0 +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- The CXX compiler identification is GNU 14.3.0 +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/g++ - skipped +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/rgfv9lch0b6ksjzlzsx0mljsb0ypqr8x-gcc-wrapper-13.4.0/bin/g++ - skipped +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ - skipped 
+torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ - skipped +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/hdw3ksc9knwilc0sc7bnzhilimcbsddm-gcc-wrapper-11.5.0/bin/g++ - skipped +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- Detecting CXX compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Check for working CXX compiler: /nix/store/d8likaw8xxdmh2qmmasbm88h74q6a2gr-gcc-wrapper-14.3.0/bin/g++ - skipped +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features +torch_harmonics_attn-torch-ext> -- Detecting CXX compile features - done +torch_harmonics_attn-torch-ext> -- FetchContent base directory: /build/source/build/_deps +torch_harmonics_attn-torch-ext> -- Found Python3: /nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found Python3: 
/nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found Python3: /nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found Python3: /nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found Python3: /nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found Python3: /nix/store/g9kwcg39as2saqv417rx3f55662ff8r9-python3-3.13.6-env/bin/python3.13 (found version "3.13.6") found components: Development Development.SABIModule Interpreter Development.Module Development.Embed +torch_harmonics_attn-torch-ext> -- Found CUDA: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85 (found version "12.6") +torch_harmonics_attn-torch-ext> -- Found CUDA: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93 (found version "12.8") +torch_harmonics_attn-torch-ext> -- Found CUDA: /nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86 (found version "12.9") +torch_harmonics_attn-torch-ext> -- Found CUDA: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85 (found version "12.6") +torch_harmonics_attn-torch-ext> -- Found CUDA: /nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89 (found version "11.8") +torch_harmonics_attn-torch-ext> -- Found CUDA: 
/nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93 (found version "12.8") +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 12.6.85 with host compiler GNU 13.4.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 12.8.93 with host compiler GNU 14.3.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 11.8.89 with host compiler GNU 11.5.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 12.9.86 with host compiler GNU 14.3.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 12.6.85 with host compiler GNU 13.4.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- The CUDA compiler identification is NVIDIA 12.8.93 with host compiler GNU 14.3.0 +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Found CUDAToolkit: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/include;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev/include 
(found version "12.6.85") +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Found CUDAToolkit: /nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89/include;/nix/store/as7abq4l2vpmgqrkxhs02ypd7swd95rv-libcublas-11.11.3.6-dev/include (found version "11.8.89") +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Found CUDAToolkit: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/include;/nix/store/fy71fffqbwg3xgvygn66kd4igj65gblv-libcublas-12.6.4.1-dev/include (found version "12.6.85") +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Found 
CUDAToolkit: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/include;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev/include (found version "12.8.93") +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Found CUDAToolkit: /nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86/include;/nix/store/dd8wl3nnsigw2gj5bwaiswla97jpw1jz-libcublas-12.9.1.4-dev/include (found version "12.9.86") +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Detecting CUDA compiler ABI info - done +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Check for working CUDA compiler: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/bin/nvcc - skipped +torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features 
+torch_harmonics_attn-torch-ext> -- Detecting CUDA compile features - done +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Found Threads: TRUE +torch_harmonics_attn-torch-ext> -- Found CUDAToolkit: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/include;/nix/store/qa4d2v0lsm6giyr4b4421qsdygz0yrrh-libcublas-12.8.4.1-dev/include (found version "12.8.93") +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Found Threads: TRUE +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Found Threads: TRUE +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads +torch_harmonics_attn-torch-ext> -- Found Threads: TRUE +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Found Threads: TRUE +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthreads - not found +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread +torch_harmonics_attn-torch-ext> -- Looking for pthread_create in pthread - found +torch_harmonics_attn-torch-ext> -- Found 
Threads: TRUE +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 12.6 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/bin/nvcc +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 11.8 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89/bin/nvcc +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/416hmc0rmh4k2ynvnq540mvh6xc0lk0f-cuda_nvcc-11.8.89 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 12.6 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85/bin/nvcc +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/7iw4ipbdy17yzvqjhxpw03i17kq7f7rj-cuda_nvcc-12.6.85 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 12.8 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/bin/nvcc +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 12.9 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86/bin/nvcc +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/8zrv6h6f2cfz34pwq012n4cx2zrv5m1s-cuda_nvcc-12.9.86 +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 12.6 +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 11.8 +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 12.6 +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 12.8 +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 12.9 
+torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/r3gwdvvsgl1csl12f4pkhz0jhsch7bdy-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ld6fk094jhhsnbip1406vrky9lmyxbax-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ld6fk094jhhsnbip1406vrky9lmyxbax-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/ld6fk094jhhsnbip1406vrky9lmyxbax-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. 
Compiling without cuFile support +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90 +torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/cmr2xc65fzb6kv5ifyhiyvzs5lay03z5-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> CMake Warning (dev) at /nix/store/0vnarm4qjnj16dr3zj9kwq6bn79c0icn-cmake-3.31.7/share/cmake-3.31/Modules/FindPackageHandleStandardArgs.cmake:441 (message): +torch_harmonics_attn-torch-ext> The package name passed to `find_package_handle_standard_args` (nvtx3) does +torch_harmonics_attn-torch-ext> not match the name of the calling package (Caffe2). This can lead to +torch_harmonics_attn-torch-ext> problems in calling code that expects `find_package` result variables +torch_harmonics_attn-torch-ext> (e.g., `_FOUND`) to follow a certain pattern. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:184 (find_package_handle_standard_args) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> This warning is for project developers. Use -Wno-dev to suppress it. +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Could NOT find nvtx3 (missing: nvtx3_dir) +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:190 (message): +torch_harmonics_attn-torch-ext> Cannot find NVTX3, find old NVTX instead +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. 
Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. Compiling without cuFile support +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/utils.cmake:328 (message): +torch_harmonics_attn-torch-ext> In the future we will require one to explicitly pass TORCH_CUDA_ARCH_LIST +torch_harmonics_attn-torch-ext> to cmake instead of implicitly setting it as an env variable. This will +torch_harmonics_attn-torch-ext> become a FATAL_ERROR in future version of pytorch. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:337 (torch_cuda_get_nvcc_gencode_flag) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA detected: 12.8 +torch_harmonics_attn-torch-ext> -- PyTorch: CUDA nvcc is: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93/bin/nvcc 
+torch_harmonics_attn-torch-ext> -- PyTorch: CUDA toolkit directory: /nix/store/8kyv8ffbfvksnqmm1kaz0llysg7dpn9z-cuda_nvcc-12.8.93 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ld6fk094jhhsnbip1406vrky9lmyxbax-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not found. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ld6fk094jhhsnbip1406vrky9lmyxbax-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/pg32mpjmckfs38anjzgyvk2ljfw12pb3-python3.13-torch-2.8.0-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not found. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/ck6gm6harngsdikmwpbjn2bmcf5gvhx4-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/7nn9pq71pfj35l1d9909vs91g9sb7wlw-python3.13-torch-2.7.1-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/j6r6hpjs8p5m4s3i8cqqavg62fd5z48g-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> CMake Warning (dev) at /nix/store/0vnarm4qjnj16dr3zj9kwq6bn79c0icn-cmake-3.31.7/share/cmake-3.31/Modules/FindPackageHandleStandardArgs.cmake:441 (message): +torch_harmonics_attn-torch-ext> The package name 
passed to `find_package_handle_standard_args` (nvtx3) does +torch_harmonics_attn-torch-ext> not match the name of the calling package (Caffe2). This can lead to +torch_harmonics_attn-torch-ext> problems in calling code that expects `find_package` result variables +torch_harmonics_attn-torch-ext> (e.g., `_FOUND`) to follow a certain pattern. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:184 (find_package_handle_standard_args) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> This warning is for project developers. Use -Wno-dev to suppress it. 
+torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Could NOT find nvtx3 (missing: nvtx3_dir) +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:190 (message): +torch_harmonics_attn-torch-ext> Cannot find NVTX3, find old NVTX instead +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. Compiling without cuFile support +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/utils.cmake:328 (message): +torch_harmonics_attn-torch-ext> In the future we will require one to explicitly pass TORCH_CUDA_ARCH_LIST +torch_harmonics_attn-torch-ext> to cmake instead of implicitly setting it as an env variable. This will +torch_harmonics_attn-torch-ext> become a FATAL_ERROR in future version of pytorch. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:337 (torch_cuda_get_nvcc_gencode_flag) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90 +torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/qal2apcjwlw2p2kk05dwqdgzh8ml687l-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/6drs80sxjhskdki55g5k1dw0jzbd258w-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/6drs80sxjhskdki55g5k1dw0jzbd258w-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/6drs80sxjhskdki55g5k1dw0jzbd258w-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) 
+torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. Compiling without cuFile support +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_101,code=sm_101;-gencode;arch=compute_120,code=sm_120 +torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/aikr517kmcd8r2nrrj70jq71d7352qiq-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/483ma0klnbln74izv5jiyila52bfwqxh-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/483ma0klnbln74izv5jiyila52bfwqxh-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/483ma0klnbln74izv5jiyila52bfwqxh-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. 
Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. Compiling without cuFile support +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_101,code=sm_101;-gencode;arch=compute_120,code=sm_120 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not found. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/dzz5brlw0xzs9hp3v8fvvwcvkmsr3ls9-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/8sicfhvzq84gnxiwybyjgp80pcynamzn-python3.13-torch-2.7.1-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/6drs80sxjhskdki55g5k1dw0jzbd258w-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not 
found. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/6drs80sxjhskdki55g5k1dw0jzbd258w-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/mrq1wi2biib2p1mks17g8g5sc4fd492r-python3.13-torch-2.8.0-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/483ma0klnbln74izv5jiyila52bfwqxh-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not found. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/483ma0klnbln74izv5jiyila52bfwqxh-python3.13-torch-2.8.0/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- PyTorch: Header version is: 12.8 +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/zccgvlbr93bhyia3sr9f2mddmkp2jyx7-python3.13-torch-2.8.0-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> -- Found Python: /nix/store/wirj6dihrpcch7idfd7jy4l0hqfsgkk1-python3-3.13.6-env/bin/python (found version "3.13.6") found components: Interpreter +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:146 (message): +torch_harmonics_attn-torch-ext> Failed to compute shorthash for libnvrtc.so +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> CMake Warning (dev) at 
/nix/store/0vnarm4qjnj16dr3zj9kwq6bn79c0icn-cmake-3.31.7/share/cmake-3.31/Modules/FindPackageHandleStandardArgs.cmake:441 (message): +torch_harmonics_attn-torch-ext> The package name passed to `find_package_handle_standard_args` (nvtx3) does +torch_harmonics_attn-torch-ext> not match the name of the calling package (Caffe2). This can lead to +torch_harmonics_attn-torch-ext> problems in calling code that expects `find_package` result variables +torch_harmonics_attn-torch-ext> (e.g., `_FOUND`) to follow a certain pattern. +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:184 (find_package_handle_standard_args) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> This warning is for project developers. Use -Wno-dev to suppress it. 
+torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Could NOT find nvtx3 (missing: nvtx3_dir) +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:190 (message): +torch_harmonics_attn-torch-ext> Cannot find NVTX3, find old NVTX instead +torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- USE_CUDNN is set to 0. Compiling without cuDNN support +torch_harmonics_attn-torch-ext> -- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support +torch_harmonics_attn-torch-ext> -- USE_CUDSS is set to 0. Compiling without cuDSS support +torch_harmonics_attn-torch-ext> -- USE_CUFILE is set to 0. Compiling without cuFile support +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/utils.cmake:328 (message): +torch_harmonics_attn-torch-ext> In the future we will require one to explicitly pass TORCH_CUDA_ARCH_LIST +torch_harmonics_attn-torch-ext> to cmake instead of implicitly setting it as an env variable. This will +torch_harmonics_attn-torch-ext> become a FATAL_ERROR in future version of pytorch. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/public/cuda.cmake:337 (torch_cuda_get_nvcc_gencode_flag) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include) +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Added CUDA NVCC flags for: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_100,code=sm_100;-gencode;arch=compute_101,code=sm_101;-gencode;arch=compute_120,code=sm_120 +torch_harmonics_attn-torch-ext> CMake Warning at /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message): +torch_harmonics_attn-torch-ext> static library kineto_LIBRARY-NOTFOUND not found. 
+torch_harmonics_attn-torch-ext> Call Stack (most recent call first): +torch_harmonics_attn-torch-ext> /nix/store/4ww34a0xcdm3baaz7y2dnrr38r2yjwwx-python3.13-torch-2.7.1/lib/python3.13/site-packages/torch/share/cmake/Torch/TorchConfig.cmake:125 (append_torchlib_if_found) +torch_harmonics_attn-torch-ext> CMakeLists.txt:30 (find_package) +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Found Torch: /nix/store/35sj4in2ddx47klyg96qmkpd4vh8py94-python3.13-torch-2.7.1-lib/lib/libtorch.so +torch_harmonics_attn-torch-ext> -- CUDA target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> -- CUDA supported target architectures: 7.0;7.5;8.0;8.6;8.9;9.0;10.0;10.1;12.0 +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0 +torch_harmonics_attn-torch-ext> -- Configuring done (9.7s) +torch_harmonics_attn-torch-ext> -- Generating done (0.0s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE 
+torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0 +torch_harmonics_attn-torch-ext> -- Configuring done (9.7s) +torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> -- Generating done (0.0s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE +torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> 
PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build +torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0;10.0 +torch_harmonics_attn-torch-ext> -- Configuring done (9.8s) +torch_harmonics_attn-torch-ext> -- Generating done (0.0s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE +torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES 
+torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build +torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0;10.0 +torch_harmonics_attn-torch-ext> -- Configuring done (9.9s) +torch_harmonics_attn-torch-ext> -- Generating done (0.0s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE +torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> 
Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build +torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0 +torch_harmonics_attn-torch-ext> -- Configuring done (10.2s) +torch_harmonics_attn-torch-ext> -- Generating done (0.1s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE +torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> 
+torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build +torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> -- Capabilities for kernel torch_harmonics_attn: 7.5;8.0;8.9;9.0;10.0 +torch_harmonics_attn-torch-ext> -- Configuring done (10.5s) +torch_harmonics_attn-torch-ext> -- Generating done (0.1s) +torch_harmonics_attn-torch-ext> CMake Warning: +torch_harmonics_attn-torch-ext> Manually-specified variables were not used by the project: +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> BUILD_TESTING +torch_harmonics_attn-torch-ext> CMAKE_EXPORT_NO_PACKAGE_REGISTRY +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_BINDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_DOCDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INCLUDEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_INFODIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LIBEXECDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_LOCALEDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_MANDIR +torch_harmonics_attn-torch-ext> CMAKE_INSTALL_SBINDIR +torch_harmonics_attn-torch-ext> CMAKE_POLICY_DEFAULT_CMP0025 +torch_harmonics_attn-torch-ext> CUDAToolkit_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PROTOC_EXE +torch_harmonics_attn-torch-ext> PYBIND11_PYTHONLIBS_OVERWRITE +torch_harmonics_attn-torch-ext> PYTHON_EXECUTABLE +torch_harmonics_attn-torch-ext> PYTHON_INCLUDE_DIR +torch_harmonics_attn-torch-ext> PYTHON_SITE_PACKAGES +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXE +torch_harmonics_attn-torch-ext> Protobuf_PROTOC_EXECUTABLE +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> +torch_harmonics_attn-torch-ext> -- Build files have been written to: /build/source/build 
+torch_harmonics_attn-torch-ext> cmake: enabled parallel building +torch_harmonics_attn-torch-ext> cmake: enabled parallel installing +torch_harmonics_attn-torch-ext> Running phase: buildPhase +torch_harmonics_attn-torch-ext> build flags: -j12 +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [1/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object 
CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [2/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch-ext/torch_binding.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_bwd.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [3/7] Building CXX object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cpu_fwd.cpp.o +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o +torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o 
+torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o +torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 1 minutes 51 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: -j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... +torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext... 
+torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/xa7svzp7i4m0wdi3wn2nad4dqh0zc2p0-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 1 minutes 52 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: -j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... 
+torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext... +torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/hb6dhflq1d914b0xgi1sznh8ian5bspd-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 1 minutes 52 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: -j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... 
+torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext... +torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/bq8dy48c7nhawb9pkc52km0wkgs0n810-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 1 minutes 54 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: 
-j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... +torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext... +torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/ys125minvg9qxngb4aihpf2w95s896yq-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object 
CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 2 minutes 48 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: -j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... +torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext... 
+torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/j9i11lbhrmq6qdcn4fdj7l4ld5n43hdx-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +torch_harmonics_attn-torch-ext> [4/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_utils.cu.o +torch_harmonics_attn-torch-ext> [5/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_fwd.cu.o +torch_harmonics_attn-torch-ext> [6/7] Building CUDA object CMakeFiles/_torch_harmonics_attn_20251001150033.dir/torch_harmonics_attn/attention_cuda_bwd.cu.o +torch_harmonics_attn-torch-ext> [7/7] Linking CXX shared module _torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> buildPhase completed in 3 minutes 31 seconds +torch_harmonics_attn-torch-ext> Running phase: installPhase +torch_harmonics_attn-torch-ext> install flags: -j12 install +torch_harmonics_attn-torch-ext> [0/1] Install the project... 
+torch_harmonics_attn-torch-ext> -- Install configuration: "Release" +torch_harmonics_attn-torch-ext> -- Installing: /nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/_torch_harmonics_attn_20251001150033/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> Running phase: fixupPhase +torch_harmonics_attn-torch-ext> shrinking RPATHs of ELF executables and libraries in /nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> shrinking /nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext/torch_harmonics_attn/_torch_harmonics_attn_20251001150033.abi3.so +torch_harmonics_attn-torch-ext> checking for references to /build/ in /nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext... +torch_harmonics_attn-torch-ext> patching script interpreter paths in /nix/store/3sangpzxl6ng9f88bj24wrpisczlvrl5-torch_harmonics_attn-torch-ext +torch_harmonics_attn-torch-ext> Running phase: installCheckPhase +torch_harmonics_attn-torch-ext> no Makefile or custom installCheckPhase, doing nothing +torch_harmonics_attn-torch-ext> Checking of ABI compatibility +torch_harmonics_attn-torch-ext> 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 +torch_harmonics_attn-torch-ext> ✅ No compatibility issues found +torch_harmonics_attn-torch-ext> Checking loading kernel with get_kernel +torch_harmonics_attn-torch-ext> Check whether the kernel can be loaded with get-kernel: torch_harmonics_attn +building '/nix/store/nlrms1s9z75f22da29ac2vgmngzfdpay-torch-ext-bundle.drv'... +building '/nix/store/pg5s7j1lx0yqdbsxzkwm8jrr4bsls395-build-and-copy.drv'... 
diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89f77ea5538f5def6c1321ff4f3eae9c101efc15 --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,14 @@ +#include + +#include "registration.h" +#include "torch_binding.h" + + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("s2_attention_bwd_dkvq_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor dy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> (Tensor, Tensor, Tensor)"); + ops.impl("s2_attention_bwd_dkvq_cuda", torch::kCUDA, &s2_attention_bwd_dkvq_cuda); + ops.def("s2_attention_fwd_cuda(Tensor kx, Tensor vx, Tensor qy, Tensor quad_weights, Tensor psi_col_idx, Tensor psi_row_off, int nlon_in, int nlat_out, int nlon_out) -> Tensor"); + ops.impl("s2_attention_fwd_cuda", torch::kCUDA, &s2_attention_fwd_cuda); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) \ No newline at end of file diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..ea74bb41d222caab5677c547cd8fcaee7473788c --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + + +std::tuple s2_attention_bwd_dkvq_cuda( + at::Tensor kx, + at::Tensor vx, + at::Tensor qy, + at::Tensor dy, + at::Tensor quad_weights, + at::Tensor psi_col_idx, + at::Tensor psi_row_off, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out +); + +torch::Tensor s2_attention_fwd_cuda( + at::Tensor kx, + at::Tensor vx, + at::Tensor qy, + at::Tensor quad_weights, + at::Tensor psi_col_idx, + at::Tensor psi_row_off, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out +); \ No newline at end of file diff --git a/torch-ext/torch_harmonics_attn/__init__.py b/torch-ext/torch_harmonics_attn/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..ea18d9e7013dc08788764ddbf5d646c9772912be --- /dev/null +++ b/torch-ext/torch_harmonics_attn/__init__.py @@ -0,0 +1,10 @@ +from ._attn_utils import backward, forward, forward_optimized, backward_optimized, _neighborhood_s2_attention_fwd_torch, _neighborhood_s2_attention_bwd_torch + +__all__ = [ + "backward", + "forward", + "forward_optimized", + "backward_optimized", + "_neighborhood_s2_attention_fwd_torch", + "_neighborhood_s2_attention_bwd_torch", +] \ No newline at end of file diff --git a/torch-ext/torch_harmonics_attn/_attn_utils.py b/torch-ext/torch_harmonics_attn/_attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a7924244c63b5b7999512411e15243be7553556 --- /dev/null +++ b/torch-ext/torch_harmonics_attn/_attn_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +from typing import Union, Tuple + +import torch +import torch.nn.functional as F + +from ._ops import ops + +def backward(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_bwd_dkvq_cuda(kx, vx, qy, dy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def forward(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out): + return ops.s2_attention_fwd_cuda(kx, vx, qy, quad_weights, psi_col_idx, psi_row_off, nlon_in, nlat_out, nlon_out) + +def _setup_context_attention_backward(ctx, inputs, output): + k, v, q, wk, wv, wq, bk, bv, bq, quad_weights, col_idx, row_off, max_psi_nnz, nh, nlon_in, nlat_out, nlon_out = inputs + ctx.save_for_backward(col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq) + ctx.nh = nh + ctx.max_psi_nnz = max_psi_nnz + ctx.nlon_in = nlon_in + ctx.nlat_out = nlat_out + ctx.nlon_out = nlon_out + +def forward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + out_shape = (kw.shape[0], vw.shape[1], nlat_out, nlon_out) + return torch.empty(out_shape, dtype=kw.dtype, device=kw.device) + +def backward_default(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor, grad_output: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: 
torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dk = torch.empty_like(kw) + dv = torch.empty_like(vw) + dq = torch.empty_like(qw) + return dk, dv, dq + + # forward +def forward_optimized(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + # convert to float32 + inp_dtype = kw.dtype + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + + output = forward(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + # convert back precision + output = output.to(dtype=inp_dtype) + + return output + +def backward_optimized(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + max_psi_nnz = ctx.max_psi_nnz + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = 
ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + # save type and convert to float32 + kw_dtype = kw.dtype + vw_dtype = vw.dtype + qw_dtype = qw.dtype + + kw = kw.to(torch.float32).contiguous() + vw = vw.to(torch.float32).contiguous() + qw = qw.to(torch.float32).contiguous() + grad_output = grad_output.to(torch.float32).contiguous() + + dkw, dvw, dqw = backward(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + # weight grads + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + dkw = dkw.to(dtype=kw_dtype) + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + dvw = dvw.to(dtype=vw_dtype) + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + dqw = dqw.to(dtype=qw_dtype) + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, 
dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None + +# torch kernels +# uses qdotk_max update trick to avoid two loops when computing the softmax +# see e.g., https://arxiv.org/abs/1805.02867 +# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/ +def _neighborhood_s2_attention_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + + + # prepare result tensor + out_shape = (qy.shape[0], vx.shape[1], nlat_out, nlon_out) + y = torch.zeros(out_shape, dtype=qy.dtype, device=qy.device) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi + wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wip = kx[:, :, hi, wip] + qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1) + + # tmp max + qdotk_max_tmp = torch.maximum(qdotk_max, qdotk) + + # alpha sum update + alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi] + alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp) + # update output + y[:,:,ho,wo] = y[:,:,ho,wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * 
vx[:,:,hi,wip] + + # define new max + qdotk_max = qdotk_max_tmp + + y[:,:,ho,wo] = y[:,:,ho,wo] / alpha_sum[:, None] + + return y + +# Explicit gradient w.r.t. vx: dM/dv +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dv_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, Cout, Ho, Wo + # quad_weights: Hi + # output + # dvx: B, Cout, Hi, Wi + + dvx = torch.zeros_like(vx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + alpha_nz[:,idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha_nz[:,idz-zstart] + + for idz in 
range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + dvx[:,:,hi, wip] += (alpha_nz[:, None, idz-zstart] / alpha_sum[:, None]) * dy[:,:,ho,wo] + + return dvx + + +# Explicit gradient w.r.t. kx: dM/dk +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dk_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dkx: B, C, Hi, Wi + + dkx = torch.zeros_like(kx) + batch_size = dy.shape[0] + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + integral = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hj_wjp = kx[:, :, hj, wjp] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hj_wjp, dim=1) + + qdotk_max, _ = torch.max(qdotk_nz, dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hj = 
nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wj = nz_col_idx % nlon_in + wjp = (wj+wo) % nlon_in + + alpha[:, idz-zstart] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hj] + alpha_sum[:] += alpha[:, idz-zstart] + + # input dot + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hj, wjp], dim=1) + + # integral term + integral[:] += alpha[:, idz-zstart] * gdotv[:] + + integral[:] = integral[:] / alpha_sum[:] + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + # compute correlation & softmax numerator + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + + dkx[:,:,hi,wip] += qy[:, :, ho, wo] * (alpha[:, None, idz-zstart] / alpha_sum[:, None]) * (gdotv[:, None] - integral[:, None]) + + return dkx + +# Explicit gradient w.r.t. 
qy: dM/dq +# provided as a reference for CUDA & other hand-written gradients +def _neighborhood_s2_attention_bwd_dq_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor, dy: torch.Tensor, + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + nlon_in: int, nlat_out: int, nlon_out: int): + + # shapes: + # input + # kx: B, C, Hi, Wi + # vx: B, Cout, Hi, Wi + # qy: B, C, Ho, Wo + # quad_weights: Hi + # output + # dq: B, C, Ho, Wo + + batch_size = dy.shape[0] + channels_in = kx.shape[1] + channels_out = vx.shape[1] + + dqy = torch.zeros_like(qy) + + for ho in range(nlat_out): + + # get number of nonzeros + zstart = row_off[ho] + zend = row_off[ho+1] + + for wo in range(nlon_out): + + alpha = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + qdotk_nz = torch.zeros((batch_size, zend-zstart), dtype=dy.dtype, device=dy.device) + alpha_k = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_vw = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_kvw = torch.zeros((batch_size, channels_in), dtype=dy.dtype, device=dy.device) + alpha_sum = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + alpha_sum2 = torch.zeros((batch_size,), dtype=dy.dtype, device=dy.device) + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + idz_i = idz-zstart + + # compute correlation & softmax numerator + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + qdotk_nz[:,idz-zstart] = torch.sum(q_ho_wo * k_hi_wi, dim=1) + + qdotk_max,_ = qdotk_nz.max(dim=1) + + for idz in range(zstart, zend): + nz_col_idx = col_idx[idz] + + # compute input indices from psi datastructure + hi = nz_col_idx // nlon_in + # account for output shift and ensure positive index due to 
circular condition + wi = nz_col_idx % nlon_in + wip = (wi+wo) % nlon_in + + q_ho_wo = qy[:, :, ho, wo] + k_hi_wi = kx[:, :, hi, wip] + idz_i = idz-zstart + alpha[:, idz_i] = torch.exp(qdotk_nz[:,idz-zstart] - qdotk_max) * quad_weights[hi] + alpha_sum[:] += alpha[:, idz_i] + + gdotv = torch.sum(dy[:,:,ho, wo] * vx[:,:,hi, wip], dim=1) + alpha_k[:,:] += alpha[:, None, idz_i] * k_hi_wi + alpha_vw[:] += alpha[:, idz_i] * gdotv[:] + alpha_kvw[:,:] += alpha[:, None, idz_i] * k_hi_wi * gdotv[:,None] + + dqy[:,:,ho,wo] = (alpha_kvw * alpha_sum[:,None] - alpha_vw[:, None] * alpha_k) / (alpha_sum[:,None] * alpha_sum[:,None]) + + return dqy + +def _neighborhood_s2_attention_torch(k: torch.Tensor, v: torch.Tensor, q: torch.Tensor, + wk: torch.Tensor, wv: torch.Tensor, wq: torch.Tensor, + bk: Union[torch.Tensor, None], bv: Union[torch.Tensor, None], bq: Union[torch.Tensor, None], + quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor, + max_psi_nnz: int, nh: int, nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor: + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + + kw = kw.to(torch.float32) + vw = vw.to(torch.float32) + qw = qw.to(torch.float32) + + output = _neighborhood_s2_attention_fwd_torch(kw, vw, qw, quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + + _, C, H, W = output.shape + output = output.reshape(B, -1, H, W) + + return output + +def _neighborhood_s2_attention_bwd_torch(ctx, grad_output): + col_idx, row_off, quad_weights, k, v, q, wk, wv, wq, bk, bv, bq = ctx.saved_tensors + nh = ctx.nh + nlon_in = ctx.nlon_in + nlat_out = ctx.nlat_out + nlon_out = ctx.nlon_out + + # check if we need the grads at all + k_needs_grad = 
ctx.needs_input_grad[0] + v_needs_grad = ctx.needs_input_grad[1] + q_needs_grad = ctx.needs_input_grad[2] + wk_needs_grad = ctx.needs_input_grad[3] + wv_needs_grad = ctx.needs_input_grad[4] + wq_needs_grad = ctx.needs_input_grad[5] + bk_needs_grad = ctx.needs_input_grad[6] + bv_needs_grad = ctx.needs_input_grad[7] + bq_needs_grad = ctx.needs_input_grad[8] + + kw = F.conv2d(k, weight=wk, bias=bk) + vw = F.conv2d(v, weight=wv, bias=bv) + qw = F.conv2d(q, weight=wq, bias=bq) + + # reshape, folding num heads into batch dim + B, _, H, W = kw.shape + kw = kw.reshape(B*nh, -1, H, W) + B, _, H, W = vw.shape + vw = vw.reshape(B*nh, -1, H, W) + B, _, H, W = qw.shape + qw = qw.reshape(B*nh, -1, H, W) + B, _, H, W = grad_output.shape + grad_output = grad_output.reshape(B*nh, -1, H, W) + + if v_needs_grad or wv_needs_grad or bv_needs_grad: + dvw = _neighborhood_s2_attention_bwd_dv_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dvw.shape + dvw = dvw.reshape(B, -1, H, W) + else: + dvw = None + + if k_needs_grad or wk_needs_grad or bk_needs_grad: + dkw = _neighborhood_s2_attention_bwd_dk_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dkw.shape + dkw = dkw.reshape(B, -1, H, W) + else: + dkw = None + + if q_needs_grad or wq_needs_grad or bq_needs_grad: + dqw = _neighborhood_s2_attention_bwd_dq_torch(kw, vw, qw, grad_output, + quad_weights, + col_idx, row_off, + nlon_in, nlat_out, nlon_out) + _, C, H, W = dqw.shape + dqw = dqw.reshape(B, -1, H, W) + else: + dqw = None + + # input grads + if v_needs_grad: + dv = torch.nn.functional.conv2d(dvw, weight=wv.permute([1,0,2,3]), bias=None) + else: + dv = None + + if k_needs_grad: + dk = torch.nn.functional.conv2d(dkw, weight=wk.permute([1,0,2,3]), bias=None) + else: + dk = None + + if q_needs_grad: + dq = torch.nn.functional.conv2d(dqw, weight=wq.permute([1,0,2,3]), bias=None) + else: + dq = None + + # weight 
grads + if wv_needs_grad: + dwv = torch.einsum("bchw,bfhw->cf", dvw, v).reshape(*wv.shape).contiguous() + else: + dwv = None + + if wk_needs_grad: + dwk = torch.einsum("bchw,bfhw->cf", dkw, k).reshape(*wk.shape).contiguous() + else: + dwk = None + + if wq_needs_grad: + dwq = torch.einsum("bchw,bfhw->cf", dqw, q).reshape(*wq.shape).contiguous() + else: + dwq = None + + # bias grads: + if bv_needs_grad: + dbv = torch.sum(dvw, dim=(0,2,3)) + else: + dbv = None + + if bk_needs_grad: + dbk = torch.sum(dkw, dim=(0,2,3)) + else: + dbk = None + + if bq_needs_grad: + dbq = torch.sum(dqw, dim=(0,2,3)) + else: + dbq = None + + return dk, dv, dq, dwk, dwv, dwq, dbk, dbv, dbq, \ + None, None, None, None, None, None, None, None diff --git a/torch_harmonics_attn/attention.h b/torch_harmonics_attn/attention.h new file mode 100644 index 0000000000000000000000000000000000000000..7198ad8c1bb522539419d68fd4a5e2646b097fc8 --- /dev/null +++ b/torch_harmonics_attn/attention.h @@ -0,0 +1,44 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2024 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include +#include +#include + +#define CHECK_CPU_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCPU) +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT_TENSOR(x) CHECK_CONTIGUOUS_TENSOR(x) +#define CHECK_CPU_INPUT_TENSOR(x) \ + CHECK_CPU_TENSOR(x); \ + CHECK_CONTIGUOUS_TENSOR(x) diff --git a/torch_harmonics_attn/attention_cpu.h b/torch_harmonics_attn/attention_cpu.h new file mode 100644 index 0000000000000000000000000000000000000000..f96e7a7b89b196c2173e05a814ed3dc467235931 --- /dev/null +++ b/torch_harmonics_attn/attention_cpu.h @@ -0,0 +1,315 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#pragma once + +#include "attention.h" +#include +#include + +#define CACHE_BLOCK_SIZE (64) + +namespace attention_kernels { + + template + void s2_attn_fwd_kernel( + const torch::PackedTensorAccessor64 kx_arr, + const torch::PackedTensorAccessor64 vx_arr, + const torch::PackedTensorAccessor64 qy_arr, + const torch::PackedTensorAccessor64 quad_weights_arr, + const torch::PackedTensorAccessor64 col_idx_arr, + const torch::PackedTensorAccessor64 roff_arr, + torch::PackedTensorAccessor64 y_arr, + const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, + const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + + // some parameters + const int64_t block_wo = CACHE_BLOCK_SIZE; + const int64_t nblock_wo = static_cast((nlon_out + block_wo - 1) / block_wo); + + #pragma omp parallel for collapse(4) + for (int64_t b = 0; b < batch_size; b++) { + for (int64_t co = 0; co < nchannels_out; co++) { + for (int64_t ho = 0; ho < nlat_out; ho++) { + for (int64_t bwo = 0; bwo < nblock_wo; bwo++) { + + // compute block start and end + int64_t wo_start = bwo * block_wo; + int64_t wo_end = std::min(nlon_out, wo_start + block_wo); + + // get number of nonzeros + int64_t zstart = roff_arr[ho]; + int64_t zend = roff_arr[ho+1]; + + // init temp aray to zero + std::array alpha_sum; + std::array qdotk_max; + std::array y_tmp; + for (int64_t wob = 0; wob < block_wo; wob++) { + alpha_sum[wob] = 0.0; + qdotk_max[wob] = -std::numeric_limits::max(); + y_tmp[wob] = 0.0; + } + + // loop over nonzeros + for (int64_t idz = zstart; idz < zend; idz++) { + // get column index + int64_t nz_col_idx = col_idx_arr[idz]; + + // compute input indices from psi datastructure + int64_t hi = static_cast(nz_col_idx / nlon_in); + // account for output shift and ensure positive index due to circular condition + int64_t wi = nz_col_idx % nlon_in; + + // loop over wo block + for (int64_t wo = wo_start; wo < wo_end; wo++) { + int64_t wip = (wi + wo) % nlon_in; + + float 
qdotk = 0.0; + //#pragma omp simd reduction(+:qdotk) + for (int64_t ci = 0; ci < nchannels_in; ci++) { + qdotk += static_cast(qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]); + } + + // update tmp max + float qdotk_max_tmp = std::max(qdotk_max[wo-wo_start], qdotk); + + // alpha sum update + float alpha = std::exp(qdotk - qdotk_max_tmp) * static_cast(quad_weights_arr[hi]); + alpha_sum[wo-wo_start] = alpha + alpha_sum[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp); + + // update output + y_tmp[wo-wo_start] = y_tmp[wo-wo_start] * std::exp(qdotk_max[wo-wo_start] - qdotk_max_tmp) + alpha * static_cast(vx_arr[b][co][hi][wip]); + + // define new max + qdotk_max[wo-wo_start] = qdotk_max_tmp; + } + } + + // update output + for (int64_t wo = wo_start; wo < wo_end; wo++) { + y_arr[b][co][ho][wo] = static_cast(y_tmp[wo-wo_start] / alpha_sum[wo-wo_start]); + } + } + } + } + } + } + + template + void s2_attn_bwd_kernel( + const torch::PackedTensorAccessor64 kx_arr, + const torch::PackedTensorAccessor64 vx_arr, + const torch::PackedTensorAccessor64 qy_arr, + const torch::PackedTensorAccessor64 dy_arr, + const torch::PackedTensorAccessor64 quad_weights_arr, + const torch::PackedTensorAccessor64 col_idx_arr, + const torch::PackedTensorAccessor64 roff_arr, + torch::PackedTensorAccessor64 dqy_arr, + torch::PackedTensorAccessor64 dvx_arr, + torch::PackedTensorAccessor64 dkx_arr, + const int64_t nlon_in, const int64_t nlat_out, const int64_t nlon_out, + const int64_t batch_size, const int64_t nchannels_in, const int64_t nchannels_out) { + + // compute dqy and dkx + #pragma omp parallel for collapse(2) + for (int64_t b = 0; b < batch_size; b++) { + for (int64_t ci = 0; ci < nchannels_in; ci++) { + + for (int64_t ho = 0; ho < nlat_out; ho++) { + + // get number of nonzeros + int64_t zstart = roff_arr[ho]; + int64_t zend = roff_arr[ho+1]; + + for (int64_t wo = 0; wo < nlon_out; wo++) { + + // required for all grads + std::vector qdotk_nz(zend-zstart); + float qdotk_max = 
-std::numeric_limits::max(); + std::vector alpha_nz(zend-zstart); + float alpha_sum = 0.0; + + // required for dkx + float alpha_gdotv = 0.0; + + // required for dqy + float alpha_k = 0.0; + float alpha_k_gdotv = 0.0; + + for (int64_t idz = zstart; idz < zend; idz++) { + int64_t nz_col_idx = col_idx_arr[idz]; + + // compute input indices from psi datastructure + int64_t hi = static_cast(nz_col_idx / nlon_in); + // account for output shift and ensure positive index due to circular condition + int64_t wi = nz_col_idx % nlon_in; + int64_t wip = (wi+wo) % nlon_in; + + // compute correlation & softmax numerator + qdotk_nz[idz-zstart] = 0.0; + for (int64_t cit = 0; cit < nchannels_in; cit++) { + qdotk_nz[idz-zstart] += qy_arr[b][cit][ho][wo] * kx_arr[b][cit][hi][wip]; + } + + // tmp max and discount + float qdotk_max_tmp = std::max(qdotk_max, qdotk_nz[idz-zstart]); + float discount = std::exp(qdotk_max - qdotk_max_tmp); + + // alpha update + alpha_nz[idz-zstart] = std::exp(qdotk_nz[idz-zstart] - qdotk_max_tmp) * quad_weights_arr[hi]; + alpha_sum = alpha_nz[idz-zstart] + alpha_sum * discount; + + // dkx: input dot + float gdotv = 0.0; + for (int64_t cot = 0; cot < nchannels_out; cot++) { + gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + } + float alpha_gdotv_tmp = alpha_nz[idz-zstart] * gdotv; + alpha_gdotv = alpha_gdotv_tmp + alpha_gdotv * discount; + + // dqy: alpha_k + alpha_k = alpha_nz[idz-zstart] * kx_arr[b][ci][hi][wip] + alpha_k * discount; + + // dqy: alpha_k_gdotv + alpha_k_gdotv = alpha_gdotv_tmp * kx_arr[b][ci][hi][wip] + alpha_k_gdotv * discount; + + // define new max + qdotk_max = qdotk_max_tmp; + } + + // normalization + alpha_gdotv = alpha_gdotv / alpha_sum; + alpha_k = alpha_k / alpha_sum; + alpha_k_gdotv = alpha_k_gdotv / alpha_sum; + + // dqy: update + dqy_arr[b][ci][ho][wo] = (alpha_k_gdotv - alpha_gdotv * alpha_k); + + for (int64_t idz = zstart; idz < zend; idz++) { + int64_t nz_col_idx = col_idx_arr[idz]; + + // compute input indices 
from psi datastructure + int64_t hi = static_cast(nz_col_idx / nlon_in); + // account for output shift and ensure positive index due to circular condition + int64_t wi = nz_col_idx % nlon_in; + int64_t wip = (wi+wo) % nlon_in; + + // dkx: alpha normalization + float alpha_norm = std::exp(qdotk_nz[idz-zstart] - qdotk_max) * quad_weights_arr[hi] / alpha_sum; + + // dkx: input dot + float gdotv = 0.0; + for (int64_t cot = 0; cot < nchannels_out; cot++) { + gdotv += dy_arr[b][cot][ho][wo] * vx_arr[b][cot][hi][wip]; + } + + // dkx: update + dkx_arr[b][ci][hi][wip] += qy_arr[b][ci][ho][wo] * alpha_norm * (gdotv - alpha_gdotv); + } + } + } + } + } + + // compute dvx + #pragma omp parallel for collapse(2) + for (int64_t b = 0; b < batch_size; b++) { + for (int64_t co = 0; co < nchannels_out; co++) { + + for (int64_t ho = 0; ho < nlat_out; ho++) { + + // get number of nonzeros + int64_t zstart = roff_arr[ho]; + int64_t zend = roff_arr[ho+1]; + + for (int64_t wo = 0; wo < nlon_out; wo++) { + + // required for all grads + std::vector qdotk_nz(zend-zstart); + float qdotk_max = -std::numeric_limits::max(); + std::vector alpha_nz(zend-zstart); + float alpha_sum = 0.0; + + for (int64_t idz = zstart; idz < zend; idz++) { + int64_t nz_col_idx = col_idx_arr[idz]; + + // compute input indices from psi datastructure + int64_t hi = static_cast(nz_col_idx / nlon_in); + // account for output shift and ensure positive index due to circular condition + int64_t wi = nz_col_idx % nlon_in; + int64_t wip = (wi+wo) % nlon_in; + + // compute correlation & softmax numerator + qdotk_nz[idz-zstart] = 0.0; + for (int64_t ci = 0; ci < nchannels_in; ci++) { + qdotk_nz[idz-zstart] += qy_arr[b][ci][ho][wo] * kx_arr[b][ci][hi][wip]; + } + + // tmp max and discount + float qdotk_max_tmp = std::max(qdotk_max, qdotk_nz[idz-zstart]); + float discount = std::exp(qdotk_max - qdotk_max_tmp); + + // alpha update + alpha_nz[idz-zstart] = std::exp(qdotk_nz[idz-zstart] - qdotk_max_tmp) * quad_weights_arr[hi]; + 
alpha_sum = alpha_nz[idz-zstart] + alpha_sum * discount; + + // define new max + qdotk_max = qdotk_max_tmp; + } + + for (int64_t idz = zstart; idz < zend; idz++) { + int64_t nz_col_idx = col_idx_arr[idz]; + + // compute input indices from psi datastructure + int64_t hi = static_cast(nz_col_idx / nlon_in); + // account for output shift and ensure positive index due to circular condition + int64_t wi = nz_col_idx % nlon_in; + int64_t wip = (wi+wo) % nlon_in; + + // recompute alpha + float alpha_norm = std::exp(qdotk_nz[idz-zstart] - qdotk_max) * quad_weights_arr[hi] / alpha_sum; + dvx_arr[b][co][hi][wip] += alpha_norm * dy_arr[b][co][ho][wo]; + } + } + } + } + } + } + + + torch::Tensor s2_attention_fwd_cpu(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, + at::Tensor col_idx, at::Tensor row_off, int64_t nlon_in, int64_t nlat_out, int64_t nlon_out); + + std::tuple s2_attention_bwd_cpu(torch::Tensor kx, torch::Tensor vx, torch::Tensor qy, torch::Tensor dy, + torch::Tensor quad_weights, torch::Tensor col_idx, torch::Tensor row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out); + +} \ No newline at end of file diff --git a/torch_harmonics_attn/attention_cpu_bwd.cpp b/torch_harmonics_attn/attention_cpu_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f7ec9b6d393eba37c2c42b0f7197642111a370f1 --- /dev/null +++ b/torch_harmonics_attn/attention_cpu_bwd.cpp @@ -0,0 +1,113 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include "attention_cpu.h" + +using namespace torch::indexing; + +namespace attention_kernels { + +std::tuple s2_attention_bwd_cpu(torch::Tensor kx, torch::Tensor vx, torch::Tensor qy, torch::Tensor dy, + torch::Tensor quad_weights, torch::Tensor col_idx, torch::Tensor row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out) { + + // shapes: + // input + // kx: B, C, Hi, Wi + // vx: B, C, Hi, Wi + // qy: B, C, Ho, Wo + // quad_weights: Hi + // output + // dkx: B, C, Hi, Wi + // dvx: B, C, Hi, Wi + // dqy: B, C, Ho, Wo + + // sanity checks + CHECK_CPU_INPUT_TENSOR(kx); + CHECK_CPU_INPUT_TENSOR(vx); + CHECK_CPU_INPUT_TENSOR(qy); + CHECK_CPU_INPUT_TENSOR(dy); + CHECK_CPU_INPUT_TENSOR(quad_weights); + CHECK_CPU_INPUT_TENSOR(col_idx); + CHECK_CPU_INPUT_TENSOR(row_off); + + // change to channels first: + bool kx_is_channels_last = kx.strides()[1] == 1; + bool vx_is_channels_last = vx.strides()[1] == 1; + bool qy_is_channels_last = qy.strides()[1] == 1; + bool dy_is_channels_last = dy.strides()[1] == 1; + + if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } + if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } + if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } + if (!dy_is_channels_last) { dy = dy.contiguous(at::MemoryFormat::ChannelsLast); } + + auto dkx = torch::zeros_like(kx); + auto dvx = torch::zeros_like(vx); + auto dqy = torch::zeros_like(qy); + + // some parameters + const int64_t batch_size = kx.size(0); + const int64_t nchannels_out = vx.size(1); + const int64_t nchannels_in = qy.size(1); + + // extract accessors + auto kx_arr = kx.packed_accessor64(); + auto vx_arr = vx.packed_accessor64(); + auto qy_arr = qy.packed_accessor64(); + auto dy_arr = dy.packed_accessor64(); + + auto quad_weights_arr = quad_weights.packed_accessor64(); + auto col_idx_arr = col_idx.packed_accessor64(); + auto roff_arr = row_off.packed_accessor64(); + + auto dqy_arr = 
dqy.packed_accessor64(); + auto dvx_arr = dvx.packed_accessor64(); + auto dkx_arr = dkx.packed_accessor64(); + + s2_attn_bwd_kernel(kx_arr, vx_arr, qy_arr, dy_arr, + quad_weights_arr, col_idx_arr, roff_arr, dqy_arr, dvx_arr, dkx_arr, + nlon_in, nlat_out, nlon_out, + batch_size, nchannels_in, nchannels_out); + + // permute back + if (!qy_is_channels_last) { dqy = dqy.contiguous(at::MemoryFormat::Contiguous); } + if (!vx_is_channels_last) { dvx = dvx.contiguous(at::MemoryFormat::Contiguous); } + if (!kx_is_channels_last) { dkx = dkx.contiguous(at::MemoryFormat::Contiguous); } + + return std::make_tuple(dkx, dvx, dqy); +} + +TORCH_LIBRARY_IMPL(attention_kernels, CPU, m) +{ + m.impl("backward", &s2_attention_bwd_cpu); +} + +} \ No newline at end of file diff --git a/torch_harmonics_attn/attention_cpu_fwd.cpp b/torch_harmonics_attn/attention_cpu_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3abbb7b366764c16ff446ed8777ad82d5e072b5d --- /dev/null +++ b/torch_harmonics_attn/attention_cpu_fwd.cpp @@ -0,0 +1,89 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "attention_cpu.h" + +using namespace torch::indexing; + +namespace attention_kernels { + + torch::Tensor s2_attention_fwd_cpu(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, + at::Tensor col_idx, at::Tensor row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out) { + // sanity checks + CHECK_CPU_INPUT_TENSOR(kx); + CHECK_CPU_INPUT_TENSOR(vx); + CHECK_CPU_INPUT_TENSOR(qy); + CHECK_CPU_INPUT_TENSOR(quad_weights); + CHECK_CPU_INPUT_TENSOR(col_idx); + CHECK_CPU_INPUT_TENSOR(row_off); + + // change to channels first: + bool kx_is_channels_last = kx.strides()[1] == 1; + bool vx_is_channels_last = vx.strides()[1] == 1; + bool qy_is_channels_last = qy.strides()[1] == 1; + + if (!kx_is_channels_last) { kx = kx.contiguous(at::MemoryFormat::ChannelsLast); } + if (!vx_is_channels_last) { vx = vx.contiguous(at::MemoryFormat::ChannelsLast); } + if (!qy_is_channels_last) { qy = qy.contiguous(at::MemoryFormat::ChannelsLast); } + + // some parameters + const int64_t batch_size = kx.size(0); + const int64_t nchannels_out = vx.size(1); + const int64_t nchannels_in = qy.size(1); + + // prepare result tensor + auto y = torch::zeros({batch_size, nchannels_out, 
nlat_out, nlon_out}, qy.options()); + + // extract accessors + auto roff_arr = row_off.packed_accessor64(); + auto col_idx_arr = col_idx.packed_accessor64(); + auto quad_weights_arr = quad_weights.packed_accessor64(); + auto vx_arr = vx.packed_accessor64(); + auto qy_arr = qy.packed_accessor64(); + auto kx_arr = kx.packed_accessor64(); + auto y_arr = y.packed_accessor64(); + + s2_attn_fwd_kernel(kx_arr, vx_arr, qy_arr, quad_weights_arr, col_idx_arr, roff_arr, y_arr, + nlon_in, nlat_out, nlon_out, batch_size, nchannels_in, nchannels_out); + + // permute back + if (!qy_is_channels_last) { y = y.contiguous(at::MemoryFormat::Contiguous); } + + return y; + } + + // Implement the operators: CPU + TORCH_LIBRARY_IMPL(attention_kernels, CPU, m) + { + m.impl("forward", &s2_attention_fwd_cpu); + } + +} \ No newline at end of file diff --git a/torch_harmonics_attn/attention_cuda.cuh b/torch_harmonics_attn/attention_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..fb099cb628a15d018d93d33604c20108a45ecbb5 --- /dev/null +++ b/torch_harmonics_attn/attention_cuda.cuh @@ -0,0 +1,51 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#pragma once + +#include +#include +#include + +#define CHECK_CUDA_TENSOR(x) TORCH_INTERNAL_ASSERT(x.device().type() == torch::kCUDA) +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_INTERNAL_ASSERT(x.is_contiguous() || x.is_contiguous(at::MemoryFormat::ChannelsLast)) +#define CHECK_CUDA_INPUT_TENSOR(x) \ + CHECK_CUDA_TENSOR(x); \ + CHECK_CONTIGUOUS_TENSOR(x) + + +torch::Tensor s2_attention_fwd_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, at::Tensor quad_weights, + at::Tensor psi_col_idx, at::Tensor psi_row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out); + +std::tuple s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, + at::Tensor dy, at::Tensor quad_weights, + at::Tensor psi_col_idx, at::Tensor psi_row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out); diff --git a/torch_harmonics_attn/attention_cuda_bwd.cu b/torch_harmonics_attn/attention_cuda_bwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..f1548a7a4101b2f2869086a689377c03d15189ab --- /dev/null +++ b/torch_harmonics_attn/attention_cuda_bwd.cu @@ -0,0 +1,1066 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. 
All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#include "attention_cuda.cuh" +#include "c10/core/MemoryFormat.h" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "cudamacro.h" +#include "attention_cuda_utils.cuh" + +#include +#include +#include + +#define THREADS (64) + +#define MAX_LOCAL_ARR_LEN (16) + + +#if 0 +class ScopeTimer +{ + public: + explicit ScopeTimer(const std::string &label = "") : + label_(label), start_(std::chrono::high_resolution_clock::now()) + { + } + + ~ScopeTimer() + { + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast(end - start_); + std::cout << label_ << "Elapsed time: " << elapsed.count() << " ms" << std::endl; + } + + private: + std::string label_; + std::chrono::high_resolution_clock::time_point start_; +}; + +// easier to understand version of manual shfl_xor_sync, performance appears similar +static __device__ float __warp_sum_cub(float val) +{ + // use cub to reduce within a warp + __shared__ typename cub::WarpReduce::TempStorage temp_storage; + + // 1. Compute sum (initially only in lane 0) + float sum = cub::WarpReduce(temp_storage).Sum(val); + // 2. Broadcast sum to all threads + sum = __shfl_sync(0xFFFFFFFF, sum, 0); + return sum; +} + +// This kernel computes the backward pass for the S2 attention mechanism, using +// shared memory as a cache and one warp per output point, warp-parallel over +// channels, which should be layed out in the fastest dimension for coalesced +// memory access. 
+template +__global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel( + int num_channels, int nlon_in, int nlat_out, int nlon_out, + const torch::PackedTensorAccessor32 kx, + const torch::PackedTensorAccessor32 vx, + const torch::PackedTensorAccessor32 qy, + const torch::PackedTensorAccessor32 dy, + torch::PackedTensorAccessor32 dydk, + torch::PackedTensorAccessor32 dydv, + torch::PackedTensorAccessor32 dydq, + const torch::PackedTensorAccessor64 psi_col_idx, + const torch::PackedTensorAccessor64 psi_row_offset, + const torch::PackedTensorAccessor32 quad_weights) +{ + + extern __shared__ float sh[]; + float *sh_alpha_k = sh + threadIdx.y * num_channels * 5; + float *sh_alpha_vw = sh_alpha_k + num_channels; + float *sh_alpha_kvw = sh_alpha_vw + num_channels; + float *sh_dy = sh_alpha_kvw + num_channels; + float *sh_qy = sh_dy + num_channels; + // (optionally, could use more shared memory for other intermediates) + + const uint64_t batchId = blockIdx.y; + const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; + if (wid >= uint64_t(nlat_out) * nlon_in) return; + const int tidx = threadIdx.x; + const int ho = wid / nlon_out; + const int wo = wid - (ho * nlon_out); + + // Zero shared memory + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + sh_alpha_k[chan] = 0.0f; + sh_alpha_vw[chan] = 0.0f; + sh_alpha_kvw[chan] = 0.0f; + sh_dy[chan] = dy[batchId][chan][ho][wo]; + sh_qy[chan] = qy[batchId][chan][ho][wo]; + } + float alpha_sum = 0.0f; + float qdotk_max = -FLT_MAX; + float integral = 0.0f; + __syncthreads(); + + const int64_t rbeg = psi_row_offset[ho]; + const int64_t rend = psi_row_offset[ho + 1]; + const int rlen = rend - rbeg; + + // 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max. 
+ for (int off = 0; off < rlen; off++) { + const int64_t col = psi_col_idx[rbeg + off]; + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + float qdotk = 0.0f, gdotv = 0.0f; + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip]; + gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; + } + qdotk = __warp_sum_cub(qdotk); + gdotv = __warp_sum_cub(gdotv); + float qdotk_max_tmp = max(qdotk_max, qdotk); + float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; + float max_correction = expf(qdotk_max - qdotk_max_tmp); + alpha_sum = alpha_sum * max_correction + alpha_inz; + integral = integral * max_correction + alpha_inz * gdotv; + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + float kxval = kx[batchId][chan][hi][wip]; + sh_alpha_k[chan] = sh_alpha_k[chan] * max_correction + alpha_inz * kxval; + sh_alpha_vw[chan] = sh_alpha_vw[chan] * max_correction + alpha_inz * gdotv; + sh_alpha_kvw[chan] = sh_alpha_kvw[chan] * max_correction + alpha_inz * kxval * gdotv; + } + qdotk_max = qdotk_max_tmp; + } + + integral /= alpha_sum; + + // Write dydq + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + dydq[batchId][chan][ho][wo] + = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum); + } + + // Third pass: accumulate gradients for k and v + for (int off = 0; off < rlen; off++) { + const int64_t col = psi_col_idx[rbeg + off]; + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + float qdotk = 0.0f, gdotv = 0.0f; + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip]; + gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip]; + } + qdotk = __warp_sum_cub(qdotk); + gdotv = __warp_sum_cub(gdotv); + float 
alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; + for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) { + float qyval = qy[batchId][chan][ho][wo]; + float dyval = sh_dy[chan]; + atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral)); + atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval); + } + } +} +#endif + +// BEGIN backward kernels and functions + +// called with (blockDim.x=32 and blockDim.y>1, BDIM=blockDim.x*blockDim.y) +template // either float or float4 +__global__ +__launch_bounds__(BDIM_X) +void s2_attn_bwd_generic_vec_k(int nchans_in, // no. of FLOATV_T elements along channel dim + int nchans_out, // no. of FLOATV_T elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + const FLOATV_T *__restrict__ kx, // [batch][nlat_in][nlon_in][nchan_in] + const FLOATV_T *__restrict__ vx, // [batch][nlat_in][nlon_in][nchan_out] + const FLOATV_T *__restrict__ qy, // [batch][nlat_out][nlon_out][nchan_in] + const FLOATV_T *__restrict__ dy, // [batch][nlat_out][nlon_out][nchan_out] + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const float *__restrict__ quad_weights, + FLOATV_T *__restrict__ dkx, // [batch][nlat_in][nlon_in][nchan_in] + FLOATV_T *__restrict__ dvx, // [batch][nlat_in][nlon_in][nchan_out] + FLOATV_T *__restrict__ dqy) { // [batch][nlat_out][nlon_out][nchan_in] + + extern __shared__ __align__(sizeof(float4)) float shext[]; + + // for dqy + FLOATV_T *sh_alpha_k__ = reinterpret_cast(shext) + threadIdx.y * (nchans_in*4+nchans_out); + FLOATV_T *sh_alpha_vw_ = sh_alpha_k__ + nchans_in; + FLOATV_T *sh_alpha_kvw = sh_alpha_vw_ + nchans_in; + + FLOATV_T *sh_dy = sh_alpha_kvw + nchans_in; + FLOATV_T *sh_qy = sh_dy + nchans_out; + // sh_alpha_k__[nchan_in], sh_alpha_vw_[nchan_in], sh_alpha_kvw[nchan_in] + // sh_dy[nchan_out], sh_qy[nchan_in] + + const int batch = blockIdx.y; + + 
const uint64_t wid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; + if (wid >= uint64_t(nlat_out)*nlon_in) { + return; + } + + const int tidx = threadIdx.x; + + // use permuted rows + const int h = wid / nlon_out; + const int wo = wid - (h*nlon_out); + const int ho = row_idx[h]; + + // offset input tensors + kx += int64_t(batch)*nlat_in*nlon_in*nchans_in; + qy += int64_t(batch)*nlat_out*nlon_out*nchans_in + int64_t(ho)*nlon_out*nchans_in + int64_t(wo)*nchans_in; + + vx += int64_t(batch)*nlat_in*nlon_in*nchans_out; + dy += int64_t(batch)*nlat_out*nlon_out*nchans_out + int64_t(ho)*nlon_out*nchans_out + int64_t(wo)*nchans_out; + + // offset output tensors + dkx += int64_t(batch)*nlat_in*nlon_in*nchans_in; + dvx += int64_t(batch)*nlat_in*nlon_in*nchans_out; + dqy += int64_t(batch)*nlat_out*nlon_out*nchans_in + int64_t(ho)*nlon_out*nchans_in + int64_t(wo)*nchans_in; + + // zero/init shared memory + for (int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + sh_alpha_k__[chan] = __vset(0.0f); + sh_alpha_vw_[chan] = __vset(0.0f); + sh_alpha_kvw[chan] = __vset(0.0f); + + sh_qy[chan] = qy[chan]; + } + for (int chan = tidx; chan < nchans_out; chan += WARP_SIZE) { + sh_dy[chan] = dy[chan]; + } + +#if __CUDA_ARCH__ < 900 + // for architectures < 9.0, sh_dy and sh_qy will be read + // as individual floats at the end of the kernel, which + // breaks the assumption that each FLOATV_T location is + // written to and read by the same thread throughout the + // kernel, in the case FLOATV_T==float4 + if constexpr(std::is_same::value) { __syncwarp(); } +#endif + + // for dkx, dvx, dqy + float alpha_sum = 0.0f; + float qdotk_max = -FLT_MAX; + + // for dkx + float integral = 0.0f; + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho+1]; + + col_idx += rbeg; + + const int rlen = rend - rbeg; + + // accumulate alpha_sum, integral, and shared stats, + // along with a progressively computed qdotk_max. 
+ for (int off = 0; off < rlen; off++) { + + const int64_t col = col_idx[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchans_in + int64_t(wip)*nchans_in; + const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchans_out + int64_t(wip)*nchans_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + + for(int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan])); + } + for(int chan = tidx; chan < nchans_out; chan += WARP_SIZE) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + + const float qdotk = __warp_sum(__vred(qdotk_v)); + const float gdotv = __warp_sum(__vred(gdotv_v)); + + const float qdotk_max_tmp = max(qdotk_max, qdotk); + const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; + const float max_correction = expf(qdotk_max - qdotk_max_tmp); + alpha_sum = alpha_sum * max_correction + alpha_inz; + + integral = integral * max_correction + alpha_inz * gdotv; + + const float ainz_gdotv = alpha_inz * gdotv; + + for (int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + + const FLOATV_T kxval = _kx[chan]; + + sh_alpha_k__[chan] = __vadd(__vscale(max_correction, sh_alpha_k__[chan]), __vscale(alpha_inz, kxval)); + sh_alpha_vw_[chan] = __vadd(__vscale(max_correction, sh_alpha_vw_[chan]), __vset(ainz_gdotv)); + sh_alpha_kvw[chan] = __vadd(__vscale(max_correction, sh_alpha_kvw[chan]), __vscale(ainz_gdotv, kxval)); + } + qdotk_max = qdotk_max_tmp; + } + + const float alpha_sum_inv = 1.0f / alpha_sum; + + integral *= alpha_sum_inv; + + // Write dqy + for (int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + + dqy[chan] = __vscale(alpha_sum_inv * alpha_sum_inv, + __vsub(__vscale(alpha_sum, sh_alpha_kvw[chan]), + __vmul(sh_alpha_vw_[chan], sh_alpha_k__[chan]))); + } + + // accumulate gradients for k and v + for 
(int off = 0; off < rlen; off++) { + + const int64_t col = col_idx[off]; + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchans_in + int64_t(wip)*nchans_in; + const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchans_out + int64_t(wip)*nchans_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + + for (int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[chan], _kx[chan])); + } + for (int chan = tidx; chan < nchans_out; chan += WARP_SIZE) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + + const float qdotk = __warp_sum(__vred(qdotk_v)); + const float gdotv = __warp_sum(__vred(gdotv_v)); + + const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; + + FLOATV_T *_dkx = dkx + int64_t(hi)*nlon_in*nchans_in + int64_t(wip)*nchans_in; + FLOATV_T *_dvx = dvx + int64_t(hi)*nlon_in*nchans_out + int64_t(wip)*nchans_out; + + const float alpha_mul = alpha_inz * alpha_sum_inv; + + const float scale_fact_qy = (gdotv - integral)*alpha_mul; + const float scale_fact_dy = alpha_mul; + + // float4, 128-bit atomics are only supported by devices of compute + // capability 9.x+, so on older devices we resort to 32-bit atomics + +#if __CUDA_ARCH__ < 1000 + // to use 32-bit operations on consecutve addresses + float *sh_qy_scl = reinterpret_cast(sh_qy); + float *sh_dy_scl = reinterpret_cast(sh_dy); + + float *_dkx_scl = reinterpret_cast(_dkx); + float *_dvx_scl = reinterpret_cast(_dvx); + + constexpr int VEC_SIZE = sizeof(FLOATV_T)/sizeof(float); + + // 32-bit, consecutive atomics to glmem; + // strided atomics results in a severe slowdown + for (int chan = tidx; chan < nchans_in*VEC_SIZE; chan += WARP_SIZE) { + atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]); + } + for (int chan = tidx; chan < nchans_out*VEC_SIZE; chan += WARP_SIZE) { + 
atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]); + } +#else + // 128-bit, consecutive atomics to glmem + for (int chan = tidx; chan < nchans_in; chan += WARP_SIZE) { + atomicAdd(_dkx + chan, __vscale(scale_fact_qy, sh_qy[chan])); + } + for (int chan = tidx; chan < nchans_out; chan += WARP_SIZE) { + atomicAdd(_dvx + chan, __vscale(scale_fact_dy, sh_dy[chan])); + } +#endif + } + + return; +} + +// called with either (BDIM_X=32 and BDIM_Y>1) || (2^K=BDIM_X > 32 and BDIM_Y=1) +template= nchan_in + typename FLOATV_T> // either float or float4 +__global__ +__launch_bounds__(BDIM_X*BDIM_Y) +void s2_attn_bwd_special_vec_k(int nchan_in, // no. of FLOATV_T elements along channel dim + int nchan_out, // no. of FLOATV_T elements along channel dim + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + const FLOATV_T *__restrict__ kx, // [batch][nlat_in][nlon_in][nchan_in] + const FLOATV_T *__restrict__ vx, // [batch][nlat_in][nlon_in][nchan_out] + const FLOATV_T *__restrict__ qy, // [batch][nlat_out][nlon_out][nchan_in] + const FLOATV_T *__restrict__ dy, // [batch][nlat_out][nlon_out][nchan_out] + const int32_t *__restrict__ row_idx, + const int64_t *__restrict__ row_off, + const int64_t *__restrict__ col_idx, + const float *__restrict__ quad_weights, + FLOATV_T *__restrict__ dkx, // [batch][nlat_in][nlon_in][nchan_in] + FLOATV_T *__restrict__ dvx, // [batch][nlat_in][nlon_in][nchan_out] + FLOATV_T *__restrict__ dqy) { // [batch][nlat_out][nlon_out][nchan_in] + + static_assert(0 == (BDIM_X & (BDIM_X-1))); + static_assert(0 == (BDIM_Y & (BDIM_Y-1))); + static_assert((BDIM_X == 32 && BDIM_Y > 1) || + (BDIM_X > 32 && BDIM_Y == 1)) ; + + constexpr int NLOC_M1 = NLOC-1; + + const int tidx = threadIdx.x; + const int batch = blockIdx.y; + const uint64_t ctaid = uint64_t(blockIdx.x) * blockDim.y + threadIdx.y; + + if (ctaid >= uint64_t(nlat_out)*nlon_in) { + return; + } + + extern __shared__ __align__(sizeof(float4)) float shext[]; + + // sh_dy[nchan_out], 
sh_qy[nchan_in] + FLOATV_T *sh_dy = reinterpret_cast(shext) + threadIdx.y*(nchan_in+nchan_out);// + tidx; + FLOATV_T *sh_qy = sh_dy + nchan_out + tidx; + + if constexpr(CHOUT_AS_IN) { + sh_dy += tidx; + } + + // for dqy + FLOATV_T loc_k__[NLOC]; + FLOATV_T loc_vw_[NLOC]; + FLOATV_T loc_kvw[NLOC]; + + // use permuted rows + const int h = ctaid / nlon_out; + const int wo = ctaid - (h*nlon_out); + const int ho = row_idx[h]; + + // offset input tensors + kx += int64_t(batch)*nlat_in*nlon_in*nchan_in + tidx; + qy += int64_t(batch)*nlat_out*nlon_out*nchan_in + int64_t(ho)*nlon_out*nchan_in + int64_t(wo)*nchan_in + tidx; + + vx += int64_t(batch)*nlat_in*nlon_in*nchan_out;// + tidx; + dy += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out;// + tidx; + if (CHOUT_AS_IN) { + vx += tidx; + dy += tidx; + } + + // offset output tensors + dkx += int64_t(batch)*nlat_in*nlon_in*nchan_in + tidx; + dvx += int64_t(batch)*nlat_in*nlon_in*nchan_out;// + tidx; + if (CHOUT_AS_IN) { + dvx += tidx; + } + dqy += int64_t(batch)*nlat_out*nlon_out*nchan_in + int64_t(ho)*nlon_out*nchan_in + int64_t(wo)*nchan_in + tidx; + + #pragma unroll + for(int i = 0; i < NLOC; i++) { + loc_k__[i] = __vset(0.0f); + loc_vw_[i] = __vset(0.0f); + loc_kvw[i] = __vset(0.0f); + } + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + sh_qy[i*BDIM_X] = qy[i*BDIM_X]; + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + sh_qy[NLOC_M1*BDIM_X] = qy[NLOC_M1*BDIM_X]; + } + + if (CHOUT_AS_IN) { + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + sh_dy[i*BDIM_X] = dy[i*BDIM_X]; + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + sh_dy[NLOC_M1*BDIM_X] = dy[NLOC_M1*BDIM_X]; + } + } else { + for(int chan = tidx; chan < nchan_out; chan += BDIM_X) { + sh_dy[chan] = dy[chan]; + } + } + +#if __CUDA_ARCH__ < 900 + // for architectures < 9.0, sh_dy and sh_qy will be read + // as individual floats at the end of the kernel, which + // breaks the assumption that each FLOATV_T location is + 
// written to and read by the same thread throughout the + // kernel, in the case FLOATV_T==float4 + if constexpr(std::is_same::value) { + if constexpr(BDIM_X == 32) { __syncwarp(); } + else { __syncthreads(); } + } +#endif + + // for dkx, dvx, dqy + float alpha_sum = 0.0f; + float qdotk_max = -FLT_MAX; + + // for dkx + float integral = 0.0f; + + const int64_t rbeg = row_off[ho]; + const int64_t rend = row_off[ho+1]; + + col_idx += rbeg; + + const int rlen = rend - rbeg; + + // accumulate alpha_sum, integral, and shared stats, + // along with a progressively computed qdotk_max. + for (int off = 0; off < rlen; off++) { + + const int64_t col = col_idx[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan_out + int64_t(wip)*nchan_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i*BDIM_X], _kx[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1*BDIM_X], _kx[NLOC_M1*BDIM_X])); + } + if constexpr(CHOUT_AS_IN) { + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i*BDIM_X], _vx[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1*BDIM_X], _vx[NLOC_M1*BDIM_X])); + } + } else { + for(int chan = tidx; chan < nchan_out; chan += BDIM_X) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + } + + float qdotk = __vred(qdotk_v); + float gdotv = __vred(gdotv_v); + + if constexpr(BDIM_X == 32) { + qdotk = __warp_sum(qdotk); + gdotv = __warp_sum(gdotv); + } else { + qdotk = __block_sum(qdotk); + gdotv = __block_sum(gdotv); + } + + const float qdotk_max_tmp = max(qdotk_max, 
qdotk); + const float alpha_inz = expf(qdotk - qdotk_max_tmp) * quad_weights[hi]; + const float max_correction = expf(qdotk_max - qdotk_max_tmp); + + alpha_sum = alpha_sum * max_correction + alpha_inz; + integral = integral * max_correction + alpha_inz * gdotv; + + const float ainz_gdotv = alpha_inz * gdotv; + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + const FLOATV_T kxval = _kx[i*BDIM_X]; + loc_k__[i] = __vadd(__vscale(max_correction, loc_k__[i]), __vscale(alpha_inz, kxval)); + loc_vw_[i] = __vadd(__vscale(max_correction, loc_vw_[i]), __vset(ainz_gdotv)); + loc_kvw[i] = __vadd(__vscale(max_correction, loc_kvw[i]), __vscale(ainz_gdotv, kxval)); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + const FLOATV_T kxval = _kx[NLOC_M1*BDIM_X]; + loc_k__[NLOC_M1] = __vadd(__vscale(max_correction, loc_k__[NLOC_M1]), __vscale(alpha_inz, kxval)); + loc_vw_[NLOC_M1] = __vadd(__vscale(max_correction, loc_vw_[NLOC_M1]), __vset(ainz_gdotv)); + loc_kvw[NLOC_M1] = __vadd(__vscale(max_correction, loc_kvw[NLOC_M1]), __vscale(ainz_gdotv, kxval)); + } + + qdotk_max = qdotk_max_tmp; + } + + const float alpha_sum_inv = 1.0f / alpha_sum; + + integral *= alpha_sum_inv; + + // Write dqy + const float alpha_sum_inv_sq = alpha_sum_inv*alpha_sum_inv; + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + dqy[i*BDIM_X] = __vscale(alpha_sum_inv_sq, + __vsub(__vscale(alpha_sum, loc_kvw[i]), + __vmul(loc_vw_[i], loc_k__[i]))); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + dqy[NLOC_M1*BDIM_X] = __vscale(alpha_sum_inv_sq, + __vsub(__vscale(alpha_sum, loc_kvw[NLOC_M1]), + __vmul(loc_vw_[NLOC_M1], loc_k__[NLOC_M1]))); + } + + // accumulate gradients for k and v + for (int off = 0; off < rlen; off++) { + + const int64_t col = col_idx[off]; + + const int hi = col / nlon_in; + const int wi = col - (hi * nlon_in); + const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in; + + const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + const FLOATV_T *_vx = vx + 
int64_t(hi)*nlon_in*nchan_out + int64_t(wip)*nchan_out; + + FLOATV_T qdotk_v = __vset(0.0f); + FLOATV_T gdotv_v = __vset(0.0f); + + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[i*BDIM_X], _kx[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + qdotk_v = __vadd(qdotk_v, __vmul(sh_qy[NLOC_M1*BDIM_X], _kx[NLOC_M1*BDIM_X])); + } + if constexpr(CHOUT_AS_IN) { + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[i*BDIM_X], _vx[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[NLOC_M1*BDIM_X], _vx[NLOC_M1*BDIM_X])); + } + } else { + for(int chan = tidx; chan < nchan_out; chan += BDIM_X) { + gdotv_v = __vadd(gdotv_v, __vmul(sh_dy[chan], _vx[chan])); + } + } + + float qdotk = __vred(qdotk_v); + float gdotv = __vred(gdotv_v); + + if constexpr(BDIM_X == 32) { + qdotk = __warp_sum(qdotk); + gdotv = __warp_sum(gdotv); + } else { + qdotk = __block_sum(qdotk); + gdotv = __block_sum(gdotv); + } + + const float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi]; + + FLOATV_T *_dkx = dkx + int64_t(hi)*nlon_in*nchan_in + int64_t(wip)*nchan_in; + FLOATV_T *_dvx = dvx + int64_t(hi)*nlon_in*nchan_out + int64_t(wip)*nchan_out; + + const float alpha_mul = alpha_inz * alpha_sum_inv; + + const float scale_fact_qy = (gdotv - integral)*alpha_mul; + const float scale_fact_dy = alpha_mul; + + // float4, 128-bit atomics are only supported by devices of compute + // capability 9.x+, so on older devices we resort to 32-bit atomics + +#if __CUDA_ARCH__ < 1000 + constexpr int VEC_SIZE = sizeof(FLOATV_T)/sizeof(float); + + // making the loop count known at compile time doesn't seem + // to make any difference here so let's keep this (much) + // simpler version + float *sh_qy_scl = reinterpret_cast(sh_qy); + float *sh_dy_scl = reinterpret_cast(sh_dy); + + float *_dkx_scl = reinterpret_cast(_dkx); + float *_dvx_scl = reinterpret_cast(_dvx); + + sh_qy_scl 
-= tidx*VEC_SIZE; + _dkx_scl -= tidx*VEC_SIZE; + if constexpr(CHOUT_AS_IN) { + sh_dy_scl -= tidx*VEC_SIZE; + _dvx_scl -= tidx*VEC_SIZE; + } + + // 32-bit, consecutive atomics to glmem + // strided atomics results in a severe slowdown + for (int chan = tidx; chan < nchan_in*VEC_SIZE; chan += BDIM_X) { + atomicAdd(_dkx_scl + chan, scale_fact_qy * sh_qy_scl[chan]); + } + for (int chan = tidx; chan < nchan_out*VEC_SIZE; chan += BDIM_X) { + atomicAdd(_dvx_scl + chan, scale_fact_dy * sh_dy_scl[chan]); + } +#else + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + atomicAdd(_dkx + i*BDIM_X, __vscale(scale_fact_qy, sh_qy[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_in) { + atomicAdd(_dkx + NLOC_M1*BDIM_X, __vscale(scale_fact_qy, sh_qy[NLOC_M1*BDIM_X])); + } + if constexpr(CHOUT_AS_IN) { + #pragma unroll + for(int i = 0; i < NLOC_M1; i++) { + atomicAdd(_dvx + i*BDIM_X, __vscale(scale_fact_dy, sh_dy[i*BDIM_X])); + } + if (NLOC_M1*BDIM_X+tidx < nchan_out) { + atomicAdd(_dvx + NLOC_M1*BDIM_X, __vscale(scale_fact_dy, sh_dy[NLOC_M1*BDIM_X])); + } + } else { + for (int chan = tidx; chan < nchan_out; chan += BDIM_X) { + atomicAdd(_dvx + chan, __vscale(scale_fact_dy, sh_dy[chan])); + } + } +#endif + } + + return; +} + +template +void launch_gen_attn_bwd(int batch_size, + int nchans_in, + int nchans_out, + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + FLOATV_T *_kxp, + FLOATV_T *_vxp, + FLOATV_T *_qyp, + FLOATV_T *_dyp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + float *_quad_weights, + FLOATV_T *_dkxp, + FLOATV_T *_dvxp, + FLOATV_T *_dqyp, + cudaStream_t stream) { + + dim3 block(WARP_SIZE, THREADS / WARP_SIZE); + dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + + size_t shsize = sizeof(FLOATV_T)*(nchans_in*4+nchans_out) * block.y; // 5 arrays per warp + + s2_attn_bwd_generic_vec_k + <<>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, + _quad_weights, 
_dkxp, _dvxp, _dqyp); + CHECK_ERROR("s2_attn_bwd_generic_vec_k"); + + return; +} + +template +void launch_spc_attn_bwd(int nloc, // "BDIM_X*nloc" >= nchans_out + int batch_size, + int nchans_in, + int nchans_out, + int nlat_in, + int nlon_in, + int nlat_out, + int nlon_out, + FLOATV_T *_kxp, + FLOATV_T *_vxp, + FLOATV_T *_qyp, + FLOATV_T *_dyp, + int32_t *_row_idx, + int64_t *_row_off, + int64_t *_col_idx, + float *_quad_weights, + FLOATV_T *_dkxp, + FLOATV_T *_dvxp, + FLOATV_T *_dqyp, + cudaStream_t stream) { + + if (CUR_LOC_SIZE == nloc) { + + dim3 block(BDIM_X, BDIM_Y); + dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size); + + size_t shsize = sizeof(FLOATV_T)*(nchans_in+nchans_out) * block.y; // 2 arrays per cta, block.y > 1 iif block.x==32 + + // nloc determines the size of local arrays used to store + // temporary buffers loc_k__[], loc_vw_[] and loc_kvw[], + // of size nchans_in each; + // if nchans_out is >= BDIM_X*(nloc-1) and <= BDIM_X*nloc + // then we can use the same compile-time known loops used + // for input channels, with the execpetion of testing + // whether to execute the last iteration based on "nchans_out" + // ibstead of "nchans_in"; in this way as long as the + // difference between the number of input and output channels + // is <= BDIM_X we can use the faster path + if (nchans_out >= BDIM_X*(CUR_LOC_SIZE-1) && + nchans_out <= BDIM_X* CUR_LOC_SIZE ) { + s2_attn_bwd_special_vec_k + <<>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, + _quad_weights, _dkxp, _dvxp, _dqyp); + } else { + s2_attn_bwd_special_vec_k + <<>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, + _quad_weights, _dkxp, _dvxp, _dqyp); + + } + CHECK_ERROR("s2_attn_bwd_special_vec_k"); + + return; + } + if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) { + launch_spc_attn_bwd(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, 
nlat_out, nlon_out, + _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, + _dkxp, _dvxp, _dqyp, stream); + } + return; +} + +static void s2_attn_bwd_dispatch(int64_t batch_size, + int64_t nchans_in, + int64_t nchans_out, + int64_t nlon_in, + int64_t nlat_out, + int64_t nlon_out, + at::Tensor kxP, + at::Tensor vxP, + at::Tensor qyP, + at::Tensor dyP, + at::Tensor row_off, + at::Tensor col_idx, + at::Tensor quad_weights, + at::Tensor dkxP, + at::Tensor dvxP, + at::Tensor dqyP) { + + static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1))); + + // get stream + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // sort row indices (ho-s) in descending order + // based on (row_off[ho+1]-row_off[ho]) + at::Tensor row_idx = sortRows(nlat_out, row_off, stream); + + const int nlat_in = kxP.size(1); + + // smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans_in + int bdimx; + bdimx = DIV_UP(nchans_in, MAX_LOCAL_ARR_LEN); + bdimx = max(bdimx, WARP_SIZE); + bdimx = next_pow2(bdimx); + + float *_kxp = reinterpret_cast(kxP.data_ptr()); + float *_vxp = reinterpret_cast(vxP.data_ptr()); + float *_qyp = reinterpret_cast(qyP.data_ptr()); + float *_dyp = reinterpret_cast(dyP.data_ptr()); + + float *_dkxp = reinterpret_cast(dkxP.data_ptr()); + float *_dvxp = reinterpret_cast(dvxP.data_ptr()); + float *_dqyp = reinterpret_cast(dqyP.data_ptr()); + + int32_t *_row_idx = reinterpret_cast(row_idx.data_ptr()); + int64_t *_row_off = reinterpret_cast(row_off.data_ptr()); + int64_t *_col_idx = reinterpret_cast(col_idx.data_ptr()); + float *_quad_weights = reinterpret_cast(quad_weights.data_ptr()); + + constexpr int VEC_SIZE = sizeof(float4) / sizeof(float); + + if (!is_aligned(_kxp) || + !is_aligned(_vxp) || + !is_aligned(_qyp) || + !is_aligned(_dyp) || + !is_aligned(_dkxp) || + !is_aligned(_dvxp) || + !is_aligned(_dqyp) || + (nchans_in % VEC_SIZE) != 0 || + (nchans_out % VEC_SIZE) != 0) { + + const int nloc = DIV_UP(nchans_in, bdimx); 
+ + // to avoid the compilation of unused template instances; + // we use a block size BDIM_X that is the smallest power of 2 + // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans_in, so + // BDIM_X > 32 are used only for: + // + // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans_in <= BDIM_X*MAX_LOCAL_ARR_LEN + constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1; + + // use 2D blocks only if 32 threads are enough; w.r.t fowrard, + // we use the special kernel only up to BDIM_X=512 as with 1024 + // each thread cannot use more than 64 registers, resulting in + // large amounts of registers spills + switch(bdimx) { + case 32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + case 64: launch_spc_attn_bwd< 64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + default: launch_gen_attn_bwd ( batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, 
nlon_out, _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp, _dvxp, _dqyp, stream); break; + } + + } else { + + float4 *_kxp4 = reinterpret_cast(kxP.data_ptr()); + float4 *_vxp4 = reinterpret_cast(vxP.data_ptr()); + float4 *_qyp4 = reinterpret_cast(qyP.data_ptr()); + float4 *_dyp4 = reinterpret_cast(dyP.data_ptr()); + + float4 *_dkxp4 = reinterpret_cast(dkxP.data_ptr()); + float4 *_dvxp4 = reinterpret_cast(dvxP.data_ptr()); + float4 *_dqyp4 = reinterpret_cast(dqyP.data_ptr()); + + nchans_in /= VEC_SIZE; + nchans_out /= VEC_SIZE; + const int nloc = DIV_UP(nchans_in, bdimx); + + constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE; + + constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1; + + // use 2D blocks only if 32 threads are enough + switch(bdimx) { + case 32: launch_spc_attn_bwd< 32, 2, 1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + case 64: launch_spc_attn_bwd< 64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + case 128: launch_spc_attn_bwd<128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + case 256: launch_spc_attn_bwd<256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + case 512: launch_spc_attn_bwd<512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, 
nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + default: launch_gen_attn_bwd ( batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _dyp4, _row_idx, _row_off, _col_idx, _quad_weights, _dkxp4, _dvxp4, _dqyp4, stream); break; + } + } + + return; +} + +// END backward kernels and functions + +std::tuple s2_attention_bwd_dkvq_cuda(at::Tensor kx, at::Tensor vx, at::Tensor qy, + at::Tensor dy, at::Tensor quad_weights, + at::Tensor psi_col_idx, at::Tensor psi_row_off, + int64_t nlon_in, int64_t nlat_out, int64_t nlon_out) +{ + + CHECK_CUDA_INPUT_TENSOR(kx); + CHECK_CUDA_INPUT_TENSOR(vx); + CHECK_CUDA_INPUT_TENSOR(qy); + CHECK_CUDA_INPUT_TENSOR(dy); + CHECK_CUDA_TENSOR(quad_weights); + CHECK_CUDA_TENSOR(psi_col_idx); + CHECK_CUDA_TENSOR(psi_row_off); + + //const size_t uo_num_channels = kx.size(1); + size_t nchans_in = qy.size(1); // or kx.size(1) + size_t nchans_out = vx.size(1); + + const int batch_size = kx.size(0); + + // extract dtype + auto kx_type = kx.dtype(); // nchans_in + auto qy_type = qy.dtype(); + auto vx_type = vx.dtype(); // ncahn_out + auto dy_type = dy.dtype(); + + torch::Tensor kxP = kx.to(torch::kFloat32); + torch::Tensor vxP = vx.to(torch::kFloat32); + torch::Tensor qyP = qy.to(torch::kFloat32); + torch::Tensor dyP = dy.to(torch::kFloat32); + + // exract memory format: this is much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast) + // the former fails for num_channels == 1 + bool kx_is_channels_last = kxP.strides()[1] == 1; + bool vx_is_channels_last = vxP.strides()[1] == 1; + bool qy_is_channels_last = qyP.strides()[1] == 1; + bool dy_is_channels_last = dyP.strides()[1] == 1; + + // transpose if required + if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); } + if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); } + if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); } + if 
(!dy_is_channels_last) { dyP = permute_4D_to0231(dyP); } + + torch::Tensor dkxP = torch::zeros_like(kxP); + torch::Tensor dvxP = torch::zeros_like(vxP); + torch::Tensor dqyP = torch::zeros_like(qyP); + + s2_attn_bwd_dispatch(batch_size, + nchans_in, + nchans_out, + nlon_in, + nlat_out, + nlon_out, + kxP, vxP, qyP, dyP, + psi_row_off, + psi_col_idx, + quad_weights, + dkxP, dvxP, dqyP); + + torch::Tensor dkx = dkxP; + torch::Tensor dvx = dvxP; + torch::Tensor dqy = dqyP; + + if (!kx_is_channels_last) { dkx = permute_4D_to0312(dkx); } + if (!vx_is_channels_last) { dvx = permute_4D_to0312(dvx); } + if (!qy_is_channels_last) { dqy = permute_4D_to0312(dqy); } + + // convert precision back to starting + dkx = dkx.to(kx_type); + dvx = dvx.to(vx_type); + dqy = dqy.to(qy_type); + + return std::make_tuple(dkx, dvx, dqy); +} + diff --git a/torch_harmonics_attn/attention_cuda_fwd.cu b/torch_harmonics_attn/attention_cuda_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..a975a0a23373b814144b444477f0b68cbfb14ef4 --- /dev/null +++ b/torch_harmonics_attn/attention_cuda_fwd.cu @@ -0,0 +1,577 @@ +// coding=utf-8 +// +// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. 
Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "attention_cuda.cuh"

// NOTE(review): the original angle-bracket include targets were lost in
// extraction; the headers below are reconstructed from the symbols used in
// this translation unit (FLT_MAX, expf, int64_t, cub, ATen/CUDA) -- verify
// against upstream torch-harmonics.
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstdint>

#include <cub/cub.cuh>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#include "cudamacro.h"
#include "attention_cuda_utils.cuh"

#define THREADS (64)

#define MAX_LOCAL_ARR_LEN (16)


// Forward attention kernel, generic (fallback) path.
//
// Called with blockDim.x == 32 and blockDim.y > 1; BDIM_X == blockDim.x*blockDim.y
// is used only for __launch_bounds__. One warp (one threadIdx.y row) computes one
// output location (ho, wo) using an online softmax (running max + rescaled running
// sums); the warp's output accumulator (nchan_out FLOATV_T elements) lives in
// dynamic shared memory, one slice per warp. Each thread only ever touches the
// shared-memory elements it wrote itself (stride-WARP_SIZE slices), so no barrier
// is required between the write and read phases.
template<int BDIM_X,
         typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X)
void s2_attn_fwd_generic_vec_k(int nchan_in,  // no. of FLOATV_T elements along channel dim
                               int nchan_out, // no. of FLOATV_T elements along channel dim
                               int nlat_in,
                               int nlon_in,
                               int nlat_out,
                               int nlon_out,
                               const FLOATV_T *__restrict__ kx,
                               const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy,
                               const int32_t *__restrict__ row_idx,
                               const int64_t *__restrict__ row_off,
                               const int64_t *__restrict__ col_idx,
                               const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ y) {

    extern __shared__ __align__(sizeof(float4)) float shext[];
    FLOATV_T *shy = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan_out;

    const int batch = blockIdx.y;
    const int wid = blockIdx.x*blockDim.y + threadIdx.y;

    if (wid >= nlat_out*nlon_out) {
        return;
    }

    const int tidx = threadIdx.x;

    const int h = wid / nlon_out;
    const int wo = wid - (h*nlon_out);
    const int ho = row_idx[h]; // rows are visited in length-sorted order (see sortRows)

    for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) {
        shy[chan] = __vset<FLOATV_T>(0.f);
    }

    // tensors are channels-last: [batch, lat, lon, chan]
    kx += int64_t(batch)*nlat_in*nlon_in*nchan_in;
    qy += int64_t(batch)*nlat_out*nlon_out*nchan_in + int64_t(ho)*nchan_in*nlon_out + int64_t(wo)*nchan_in;

    vx += int64_t(batch)*nlat_in*nlon_in*nchan_out;
    y  += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nchan_out*nlon_out + int64_t(wo)*nchan_out;

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho+1];

    col_idx += rbeg;

    const int rlen = rend-rbeg;

    for(int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi*nlon_in);
        const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in; // (wi+wo) % nlon_in

        const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan_in  + int64_t(wip)*nchan_in;
        const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan_out + int64_t(wip)*nchan_out;

        FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);

        for(int chan = tidx; chan < nchan_in; chan += WARP_SIZE) {
            qdotkv = __vadd(qdotkv,
                            __vmul(qy[chan],
                                   _kx[chan]));
        }

        float qdotk = __warp_sum(__vred(qdotkv));

        // online softmax update: rescale the running accumulators by
        // exp(old_max - new_max) before folding in the new term
        float qdotk_max_tmp;
        float alpha;
        float exp_save;

        qdotk_max_tmp = max(qdotk_max, qdotk);
        alpha    = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum*exp_save;

        for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) {
            shy[chan] = __vadd(__vscale(exp_save, shy[chan]),
                               __vscale(   alpha, _vx[chan]));
        }
        qdotk_max = qdotk_max_tmp;
    }

    alpha_sum = 1.0f / alpha_sum;
    for(int chan = tidx; chan < nchan_out; chan += WARP_SIZE) {
        y[chan] = __vscale(alpha_sum, shy[chan]);
    }

    return;
}

// Forward attention kernel, specialized (fast) path.
//
// Called with either (BDIM_X == 32 and BDIM_Y > 1) or (BDIM_X == 2^K > 32 and
// BDIM_Y == 1). One CTA row computes one output location (ho, wo); the output
// vector is held in a per-thread register array of NLOC FLOATV_T elements
// (BDIM_X*NLOC >= nchan_out) and the query vector is staged through shared
// memory (each thread reads back only the elements it wrote, so no barrier is
// needed). CHIN_AS_OUT selects a fully unrolled query load/dot-product when
// nchan_in fits the same BDIM_X*NLOC tiling as nchan_out.
template<int BDIM_X,
         int BDIM_Y,
         bool CHIN_AS_OUT,
         int NLOC,          // "BDIM_X*NLOC" >= nchan_out
         typename FLOATV_T> // either float or float4
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void s2_attn_fwd_special_vec_k(int nchan_in,  // no. of FLOATV_T elements along channel dim
                               int nchan_out, // no. of FLOATV_T elements along channel dim
                               int nlat_in,
                               int nlon_in,
                               int nlat_out,
                               int nlon_out,
                               const FLOATV_T *__restrict__ kx,
                               const FLOATV_T *__restrict__ vx,
                               const FLOATV_T *__restrict__ qy,
                               const int32_t *__restrict__ row_idx,
                               const int64_t *__restrict__ row_off,
                               const int64_t *__restrict__ col_idx,
                               const float *__restrict__ quad_weights,
                               FLOATV_T *__restrict__ y) {

    static_assert(0 == (BDIM_X & (BDIM_X-1)));
    static_assert(0 == (BDIM_Y & (BDIM_Y-1)));
    static_assert((BDIM_X == 32 && BDIM_Y > 1) ||
                  (BDIM_X > 32 && BDIM_Y == 1));

    constexpr int NLOC_M1 = NLOC-1;

    const int tidx = threadIdx.x;
    const int batch = blockIdx.y;
    const int ctaid = blockIdx.x*blockDim.y + threadIdx.y;

    if (ctaid >= nlat_out*nlon_out) {
        return;
    }

    FLOATV_T locy[NLOC];

    extern __shared__ __align__(sizeof(float4)) float shext[];
    FLOATV_T *shq = reinterpret_cast<FLOATV_T *>(shext) + threadIdx.y*nchan_in;

    if constexpr(CHIN_AS_OUT) {
        shq += tidx;
    }

    const int h = ctaid / nlon_out;
    const int wo = ctaid - (h*nlon_out);
    const int ho = row_idx[h];

    kx += int64_t(batch)*nlat_in*nlon_in*nchan_in;
    qy += int64_t(batch)*nlat_out*nlon_out*nchan_in + int64_t(ho)*nlon_out*nchan_in + int64_t(wo)*nchan_in;
    if constexpr(CHIN_AS_OUT) {
        kx += tidx;
        qy += tidx;
    }

    vx += int64_t(batch)*nlat_in*nlon_in*nchan_out + tidx;
    y  += int64_t(batch)*nlat_out*nlon_out*nchan_out + int64_t(ho)*nlon_out*nchan_out + int64_t(wo)*nchan_out + tidx;

    #pragma unroll
    for(int i = 0; i < NLOC; i++) {
        locy[i] = __vset<FLOATV_T>(0.f);
    }

    // stage the query vector; last (partial) tile is guarded on nchan_in
    if constexpr(CHIN_AS_OUT) {
        #pragma unroll
        for(int i = 0; i < NLOC_M1; i++) {
            shq[i*BDIM_X] = qy[i*BDIM_X];
        }
        if (NLOC_M1*BDIM_X+tidx < nchan_in) {
            shq[NLOC_M1*BDIM_X] = qy[NLOC_M1*BDIM_X];
        }
    } else {
        for(int chan = tidx; chan < nchan_in; chan += BDIM_X) {
            shq[chan] = qy[chan];
        }
    }

    float alpha_sum = 0.0f;
    float qdotk_max = -FLT_MAX;

    const int64_t rbeg = row_off[ho];
    const int64_t rend = row_off[ho+1];

    col_idx += rbeg;

    const int rlen = rend-rbeg;

    for(int off = 0; off < rlen; off++) {

        const int64_t col = col_idx[off];

        const int hi = col / nlon_in;
        const int wi = col - (hi*nlon_in);
        const int wip = (wi+wo) - ((wi+wo) / nlon_in) * nlon_in; // (wi+wo) % nlon_in

        const FLOATV_T *_kx = kx + int64_t(hi)*nlon_in*nchan_in  + int64_t(wip)*nchan_in;
        const FLOATV_T *_vx = vx + int64_t(hi)*nlon_in*nchan_out + int64_t(wip)*nchan_out;

        FLOATV_T qdotkv = __vset<FLOATV_T>(0.f);

        if constexpr(CHIN_AS_OUT) {
            #pragma unroll
            for(int i = 0; i < NLOC_M1; i++) {
                qdotkv = __vadd(qdotkv,
                                __vmul(shq[i*BDIM_X],
                                       _kx[i*BDIM_X]));
            }
            if (NLOC_M1*BDIM_X+tidx < nchan_in) {
                qdotkv = __vadd(qdotkv,
                                __vmul(shq[NLOC_M1*BDIM_X],
                                       _kx[NLOC_M1*BDIM_X]));
            }
        } else {
            for(int chan = tidx; chan < nchan_in; chan += BDIM_X) {
                qdotkv = __vadd(qdotkv, __vmul(shq[chan], _kx[chan]));
            }
        }

        float qdotk = __vred(qdotkv);
        if constexpr(BDIM_X == 32) { qdotk = __warp_sum(qdotk); }
        else                       { qdotk = __block_sum<BDIM_X, BDIM_Y, 1>(qdotk); }

        // online softmax update (see generic kernel)
        float qdotk_max_tmp;
        float alpha;
        float exp_save;

        qdotk_max_tmp = max(qdotk_max, qdotk);
        alpha    = expf(qdotk - qdotk_max_tmp) * quad_weights[hi];
        exp_save = expf(qdotk_max - qdotk_max_tmp);

        alpha_sum = alpha + alpha_sum*exp_save;

        #pragma unroll
        for(int i = 0; i < NLOC_M1; i++) {
            locy[i] = __vadd(__vscale(exp_save, locy[i]),
                             __vscale(alpha, _vx[i*BDIM_X]));
        }
        if (NLOC_M1*BDIM_X+tidx < nchan_out) {
            locy[NLOC_M1] = __vadd(__vscale(exp_save, locy[NLOC_M1]),
                                   __vscale(alpha, _vx[NLOC_M1*BDIM_X]));
        }

        qdotk_max = qdotk_max_tmp;
    }

    alpha_sum = 1.0f / alpha_sum;

    #pragma unroll
    for(int i = 0; i < NLOC_M1; i++) {
        y[i*BDIM_X] = __vscale(alpha_sum, locy[i]);
    }
    if (NLOC_M1*BDIM_X+tidx < nchan_out) {
        y[NLOC_M1*BDIM_X] = __vscale(alpha_sum, locy[NLOC_M1]);
    }

    return;
}

// Host-side launcher for the generic forward kernel (any channel count).
template<typename FLOATV_T>
void launch_gen_attn_fwd(int batch_size,
                         int nchans_in,
                         int nchans_out,
                         int nlat_in,
                         int nlon_in,
                         int nlat_out,
                         int nlon_out,
                         FLOATV_T *__restrict__ _kxp,
                         FLOATV_T *__restrict__ _vxp,
                         FLOATV_T *__restrict__ _qyp,
                         int32_t *_row_idx,
                         int64_t *_row_off,
                         int64_t *_col_idx,
                         float *_quad_weights,
                         FLOATV_T *__restrict__ _yp,
                         cudaStream_t stream) {

    dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
    dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);

    // one output accumulator slice per warp
    size_t shsize = sizeof(FLOATV_T)*nchans_out * block.y;

    s2_attn_fwd_generic_vec_k<THREADS>
        <<<grid, block, shsize, stream>>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out,
                                          _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
    CHECK_ERROR("s2_attn_fwd_generic_vec_k");

    return;
}

// Host-side launcher for the specialized kernel; recurses at compile time over
// CUR_LOC_SIZE until it matches the runtime nloc, so only the required
// instantiation is launched.
template<int BDIM_X,
         int BDIM_Y,
         int CUR_LOC_SIZE,
         int MAX_LOC_SIZE,
         typename FLOATV_T>
void launch_spc_attn_fwd(int nloc, // "BDIM_X*nloc" >= nchans_out
                         int batch_size,
                         int nchans_in,
                         int nchans_out,
                         int nlat_in,
                         int nlon_in,
                         int nlat_out,
                         int nlon_out,
                         FLOATV_T *__restrict__ _kxp,
                         FLOATV_T *__restrict__ _vxp,
                         FLOATV_T *__restrict__ _qyp,
                         int32_t *_row_idx,
                         int64_t *_row_off,
                         int64_t *_col_idx,
                         float *_quad_weights,
                         FLOATV_T *__restrict__ _yp,
                         cudaStream_t stream) {

    if (CUR_LOC_SIZE == nloc) {

        dim3 block(BDIM_X, BDIM_Y);
        dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);

        //size_t shsize = sizeof(FLOATV_T)*nchans_out * block.y; // block.y > 1 iif block.x==32
        size_t shsize = sizeof(FLOATV_T)*nchans_in * block.y; // block.y > 1 iif block.x==32

        // nloc determines the size of local arrays used to store
        // y vectors, of length nchans_out;
        // if nchans_in is >= BDIM_X*(nloc-1) and <= BDIM_X*nloc
        // then we can use the same compile-time known loops used
        // for output channels, with the exception of testing
        // whether to execute the last iteration based on "nchans_in"
        // rather than on "nchans_out"; in this way as long as the
        // difference between the number of input and output channels
        // is <= BDIM_X we can use the faster path
        if (nchans_in >= BDIM_X*(CUR_LOC_SIZE-1) &&
            nchans_in <= BDIM_X* CUR_LOC_SIZE ) {

            s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, true, CUR_LOC_SIZE>
                <<<grid, block, shsize, stream>>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out,
                                                  _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        } else {

            s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, false, CUR_LOC_SIZE>
                <<<grid, block, shsize, stream>>>(nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out,
                                                  _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        }
        CHECK_ERROR("s2_attn_fwd_special_vec_k");

        return;
    }
    if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
        launch_spc_attn_fwd<BDIM_X, BDIM_Y, CUR_LOC_SIZE+1, MAX_LOC_SIZE>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out,
                                                                          _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp,
                                                                          stream);
    }
    return;
}

// Selects vectorized (float4) vs scalar kernels and the block size, then
// launches the forward pass. All tensors must already be float32 and
// channels-last (see s2_attention_fwd_cuda).
static void s2_attn_fwd_dispatch(int64_t batch_size,
                                 int64_t nchans_in,
                                 int64_t nchans_out,
                                 int64_t nlon_in,
                                 int64_t nlat_out,
                                 int64_t nlon_out,
                                 at::Tensor kxP,
                                 at::Tensor vxP,
                                 at::Tensor qyP,
                                 at::Tensor row_off,
                                 at::Tensor col_idx,
                                 at::Tensor quad_weights,
                                 at::Tensor yP) {

    static_assert(0 == (MAX_LOCAL_ARR_LEN & (MAX_LOCAL_ARR_LEN-1)));

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    // sort row indices (ho-s) in descending order
    // based on (row_off[ho+1]-row_off[ho])
    at::Tensor row_idx = sortRows(nlat_out, row_off, stream);

    const int nlat_in = kxP.size(1); // kxP is channels-last: [B, nlat_in, nlon_in, C]

    // smallest power of two "bdimx" (>=32) s.t. bdimx*MAX_LOCAL_ARR_LEN >= nchans_out
    int bdimx;
    bdimx = DIV_UP(nchans_out, MAX_LOCAL_ARR_LEN);
    bdimx = max(bdimx, WARP_SIZE);
    bdimx = next_pow2(bdimx);

    float *_kxp = reinterpret_cast<float *>(kxP.data_ptr());
    float *_vxp = reinterpret_cast<float *>(vxP.data_ptr());
    float *_qyp = reinterpret_cast<float *>(qyP.data_ptr());
    float *_yp  = reinterpret_cast<float *>(yP.data_ptr());

    int32_t *_row_idx = reinterpret_cast<int32_t *>(row_idx.data_ptr());
    int64_t *_row_off = reinterpret_cast<int64_t *>(row_off.data_ptr());
    int64_t *_col_idx = reinterpret_cast<int64_t *>(col_idx.data_ptr());
    float *_quad_weights = reinterpret_cast<float *>(quad_weights.data_ptr());

    constexpr int VEC_SIZE = sizeof(float4) / sizeof(float);

    if (!is_aligned<sizeof(float4)>(_kxp) ||
        !is_aligned<sizeof(float4)>(_vxp) ||
        !is_aligned<sizeof(float4)>(_qyp) ||
        !is_aligned<sizeof(float4)>(_yp)  ||
        (nchans_in  % VEC_SIZE) != 0 ||
        (nchans_out % VEC_SIZE) != 0) {

        const int nloc = DIV_UP(nchans_out, bdimx);

        // to avoid the compilation of unused template instances;
        // we use a block size BDIM_X that is the smallest power of 2
        // such that BDIM_X*MAX_LOCAL_ARR_LEN >= nchans_out, so
        // BDIM_X > 32 are used only for:
        //
        // (BDIM_X-1)*MAX_LOCAL_ARR_LEN < nchans_out <= BDIM_X*MAX_LOCAL_ARR_LEN
        constexpr int MIN_LOC_ARR_LEN = MAX_LOCAL_ARR_LEN/2+1;

        // use 2D blocks only if 32 threads are enough
        switch(bdimx) {
            case   32: launch_spc_attn_fwd<  32, 2,               1, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case   64: launch_spc_attn_fwd<  64, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  128: launch_spc_attn_fwd< 128, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  256: launch_spc_attn_fwd< 256, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case  512: launch_spc_attn_fwd< 512, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_ARR_LEN, MAX_LOCAL_ARR_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
            default:   launch_gen_attn_fwd                                             (      batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp, stream); break;
        }

    } else {

        // pointers are 16-byte aligned and channel counts divide by 4:
        // use the float4 vectorized path
        float4 *_kxp4 = reinterpret_cast<float4 *>(_kxp);
        float4 *_vxp4 = reinterpret_cast<float4 *>(_vxp);
        float4 *_qyp4 = reinterpret_cast<float4 *>(_qyp);
        float4 *_yp4  = reinterpret_cast<float4 *>(_yp);

        nchans_in  /= VEC_SIZE;
        nchans_out /= VEC_SIZE;
        const int nloc = DIV_UP(nchans_out, bdimx);

        constexpr int MAX_LOCAL_VEC_LEN = MAX_LOCAL_ARR_LEN / VEC_SIZE;

        constexpr int MIN_LOC_VEC_LEN = MAX_LOCAL_VEC_LEN/2+1;

        // use 2D blocks only if 32 threads are enough
        switch(bdimx) {
            case   32: launch_spc_attn_fwd<  32, 2,               1, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case   64: launch_spc_attn_fwd<  64, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  128: launch_spc_attn_fwd< 128, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  256: launch_spc_attn_fwd< 256, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case  512: launch_spc_attn_fwd< 512, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            case 1024: launch_spc_attn_fwd<1024, 1, MIN_LOC_VEC_LEN, MAX_LOCAL_VEC_LEN>(nloc, batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
            default:   launch_gen_attn_fwd                                             (      batch_size, nchans_in, nchans_out, nlat_in, nlon_in, nlat_out, nlon_out, _kxp4, _vxp4, _qyp4, _row_idx, _row_off, _col_idx, _quad_weights, _yp4, stream); break;
        }
    }

    return;
}

// END - forward kernels and functions

// Entry point: neighborhood S2 attention forward pass.
// Converts inputs to float32 channels-last, dispatches the CUDA kernels and
// converts the result back to the caller's dtype/layout.
torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
                                    at::Tensor vx,
                                    at::Tensor qy,
                                    at::Tensor quad_weights,
                                    at::Tensor psi_col_idx,
                                    at::Tensor psi_row_off,
                                    int64_t nlon_in,
                                    int64_t nlat_out,
                                    int64_t nlon_out) {
    CHECK_CUDA_INPUT_TENSOR(kx);
    CHECK_CUDA_INPUT_TENSOR(vx);
    CHECK_CUDA_INPUT_TENSOR(qy);
    CHECK_CUDA_TENSOR(quad_weights);
    CHECK_CUDA_TENSOR(psi_col_idx);
    CHECK_CUDA_TENSOR(psi_row_off);

    size_t nchans_in  = qy.size(1); // or kx.size(1)
    size_t nchans_out = vx.size(1);

    const int batch_size = kx.size(0);

    // extract dtype
    auto qy_type = qy.dtype();

    torch::Tensor kxP = kx.to(torch::kFloat32);
    torch::Tensor vxP = vx.to(torch::kFloat32);
    torch::Tensor qyP = qy.to(torch::kFloat32);

    // these are much safer than checking is_contiguous(at::MemoryFormat::ChannelsLast)
    // the former fails for num_channels == 1
    bool kx_is_channels_last = kxP.strides()[1] == 1;
    bool vx_is_channels_last = vxP.strides()[1] == 1;
    bool qy_is_channels_last = qyP.strides()[1] == 1;

    if (!kx_is_channels_last) { kxP = permute_4D_to0231(kxP); }
    if (!vx_is_channels_last) { vxP = permute_4D_to0231(vxP); }
    if (!qy_is_channels_last) { qyP = permute_4D_to0231(qyP); }

    torch::Tensor yP = torch::empty_like(vxP);

    s2_attn_fwd_dispatch(batch_size,
                         nchans_in,
                         nchans_out,
                         nlon_in,
                         nlat_out,
                         nlon_out,
                         kxP, vxP, qyP,
                         psi_row_off,
                         psi_col_idx,
                         quad_weights,
                         yP);

    torch::Tensor y = yP;
    if (!qy_is_channels_last) { y = permute_4D_to0312(y); }

    // convert precision back to starting
    y = y.to(qy_type);

    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return y;
}

diff --git a/torch_harmonics_attn/attention_cuda_utils.cu b/torch_harmonics_attn/attention_cuda_utils.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f48f6aabd4f2be63d18889fa2f636434fc25e39f
--- /dev/null
+++ b/torch_harmonics_attn/attention_cuda_utils.cu
@@ -0,0 +1,180 @@
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. 
Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "attention_cuda_utils.cuh"

// NOTE(review): the original angle-bracket include targets were lost in
// extraction; reconstructed from the symbols used below -- verify upstream.
#include <cassert>
#include <cstdint>
#include <cstdio>

#include <cub/cub.cuh>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#include "cudamacro.h"
#include "attention_cuda.cuh"

#define THREADS (64)

#define TRANSP_WARPS_X_TILE_GENERIC (32)
#define TRANSP_WARPS_X_TILE_SM100 (4)

// BEGIN - CSR rows sorting kernels and functions

// Fills rids with the identity permutation [0, n) and rlen with the CSR row
// lengths (offs[i+1]-offs[i]); grid-stride loop so any launch config works.
__global__ void set_rlen_rids_k(const int n,
                                const int64_t *__restrict__ offs,
                                int *__restrict__ rids,
                                int *__restrict__ rlen) {

    const int nth = gridDim.x*blockDim.x;
    const int tid = blockIdx.x*blockDim.x + threadIdx.x;

    for(int i = tid; i < n; i += nth) {
        rids[i] = i;
        rlen[i] = offs[i+1]-offs[i];
    }

    return;
}

// Returns a device tensor with the output-row indices sorted by descending
// CSR row length (so longer rows are scheduled first by the attention kernels).
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream) {

    int64_t *_row_off_d = reinterpret_cast<int64_t *>(row_off.data_ptr());

    auto options = torch::TensorOptions().dtype(torch::kInt32).device(row_off.device());

    torch::Tensor rids_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_d = torch::empty({nlat_out}, options);

    int *_rids_d = reinterpret_cast<int *>(rids_d.data_ptr());
    int *_rlen_d = reinterpret_cast<int *>(rlen_d.data_ptr());

    const int grid = DIV_UP(nlat_out, THREADS);
    const int block = THREADS;

    set_rlen_rids_k<<<grid, block, 0, stream>>>(nlat_out,
                                                _row_off_d,
                                                _rids_d,
                                                _rlen_d);
    CHECK_ERROR("set_rlen_rids_k"); // was missing: surface launch errors here

    torch::Tensor rids_sort_d = torch::empty({nlat_out}, options);
    torch::Tensor rlen_sort_d = torch::empty({nlat_out}, options);

    int *_rids_sort_d = reinterpret_cast<int *>(rids_sort_d.data_ptr());
    int *_rlen_sort_d = reinterpret_cast<int *>(rlen_sort_d.data_ptr());

    // first CUB call only queries the required temp storage size
    size_t temp_storage_bytes = 0;
    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(NULL, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));

    options = torch::TensorOptions().dtype(torch::kByte).device(row_off.device());
    torch::Tensor temp_storage_d = torch::empty({int64_t(temp_storage_bytes)}, options);

    void *_temp_storage_d = reinterpret_cast<void *>(temp_storage_d.data_ptr());

    CHECK_CUDA(cub::DeviceRadixSort::SortPairsDescending(_temp_storage_d, temp_storage_bytes,
                                                         _rlen_d, _rlen_sort_d,
                                                         _rids_d, _rids_sort_d,
                                                         nlat_out, 0, sizeof(*_rlen_d)*8, stream));
    return rids_sort_d;
}
// END - CSR rows sorting kernels and functions


// BEGIN - 4D tensor permutation kernels and functions

// Dummy kernel used only to query the PTX version the file was built for.
__global__ void empty_k() {}

static int getPtxver() {
    cudaFuncAttributes attrs;
    CHECK_CUDA(cudaFuncGetAttributes(&attrs, empty_k));
    return attrs.ptxVersion*10;
}

// Permutes [B, C, H, W] -> [B, H, W, C] (channels-last), picking the tile
// width by target architecture.
at::Tensor permute_4D_to0231(at::Tensor src) {

    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(2), src.size(3), src.size(1)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_generic", ([&] {
            launch_permute_to0231<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0231_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0231_k_tile_sm100", ([&] {
            launch_permute_to0231<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0231_k_tile_sm100");
    }

    return dst;
}

// Permutes [B, H, W, C] (channels-last) -> [B, C, H, W].
at::Tensor permute_4D_to0312(at::Tensor src) {

    auto options = torch::TensorOptions().dtype(src.dtype()).device(src.device());
    torch::Tensor dst = torch::empty({src.size(0), src.size(3), src.size(1), src.size(2)}, options);

    const int ptxv = getPtxver();

    // to be further specialized for additional archs, if necessary
    if (ptxv < 100) {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_generic", ([&] {
            launch_permute_to0312<TRANSP_WARPS_X_TILE_GENERIC, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0312_k_tile_generic");
    } else {
        AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "permute_to0312_k_tile_sm100", ([&] {
            launch_permute_to0312<TRANSP_WARPS_X_TILE_SM100, scalar_t>(src, dst);
        }));
        CHECK_ERROR("permute_to0312_k_tile_sm100");
    }

    return dst;
}
// END - tensor permutation kernels and functions

// BEGIN - general host-side functions

// Smallest power of two >= x (bit-smearing); next_pow2(0) == 0 by this scheme.
unsigned int next_pow2(unsigned int x) {

    x -= 1;

    #pragma unroll
    for(int i = 1; i <= sizeof(x)*8 / 2; i *= 2) {
        x |= x >> i;
    }
    return x+1;
}
// END - general host-side functions
diff --git a/torch_harmonics_attn/attention_cuda_utils.cuh b/torch_harmonics_attn/attention_cuda_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..34b342f7fd0adb4743a5fe579c64443d6b249cf2
--- /dev/null
+++ b/torch_harmonics_attn/attention_cuda_utils.cuh
@@ -0,0 +1,376 @@
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. 
Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#pragma once

// NOTE(review): the original angle-bracket include targets were lost in
// extraction; reconstructed from the symbols used below -- verify upstream.
#include <cstddef>
#include <cstdint>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define DIV_UP(a,b) (((a)+((b)-1))/(b))


// CSR rows sorting kernels and functions
at::Tensor sortRows(int nlat_out, at::Tensor row_off, cudaStream_t stream);

// 4D tensor permutation kernels and functions
at::Tensor permute_4D_to0231(at::Tensor src);
at::Tensor permute_4D_to0312(at::Tensor src);

// Host tensor dump and CSR manipulation functions
void dump_tensor(const char *fname, at::Tensor t);
void dump_csr(const char *fname, at::Tensor roff, at::Tensor cols);

int part_csr_rows(int *row_perm,
                  const at::Tensor roff,
                  const at::Tensor cols,
                  int **part_off,
                  int **part_val);

int verify_part(const int npart,
                const int *part_off,
                const int *part_val,
                const at::Tensor roff,
                const at::Tensor cols);

void verify_part_new(const int nlon_out,
                     const int nlat_in,
                     const int nlon_in,
                     const int npart, // partitioning data
                     const int *part_off,
                     const int *part_val,
                     const at::Tensor roff,
                     const at::Tensor cols);

unsigned int next_pow2(unsigned int x);


// utility host functions and templates

// Nonzero iff ptr is ALIGN-byte aligned; ALIGN must be a power of two.
template<size_t ALIGN>
int is_aligned(const void *ptr) {

    static_assert(0 == (ALIGN & (ALIGN-1)));
    return (0 == (uintptr_t(ptr) & (ALIGN-1)));
}


// utility device functions and templates
//
// Small vector-math shims so the attention kernels can be written once for
// both scalar float and vectorized float4 element types.

template<typename FLOATV_T>
__device__ FLOATV_T __vset(float x) {
    static_assert(sizeof(FLOATV_T) == 0, "Unsupported type for __vset");
    return FLOATV_T{};
}

template<>
__device__ float __forceinline__ __vset<float>(float x) {
    return x;
}

__device__ float __forceinline__ __vmul(float a, float b) {
    return a*b;
}

__device__ float __forceinline__ __vadd(float a, float b) {
    return a+b;
}

__device__ float __forceinline__ __vsub(float a, float b) {
    return a-b;
}

__device__ float __forceinline__ __vred(float a) {
    return a;
}

__device__ float __forceinline__ __vscale(float s, float v) {
    return v*s;
}

__device__ float __forceinline__ __vdiv(float s, float v) {
    return v/s;
}

template<>
__device__ float4 __forceinline__ __vset<float4>(float x) {
    return make_float4(x, x, x, x);
}

__device__ float4 __forceinline__ __vmul(float4 a, float4 b) {
    return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w);
}

__device__ float4 __forceinline__ __vadd(float4 a, float4 b) {
    return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w);
}

__device__ float4 __forceinline__ __vsub(float4 a, float4 b) {
    return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w);
}

__device__ float __forceinline__ __vred(float4 a) {
    return a.x + a.y + a.z + a.w;
}

__device__ float4 __forceinline__ __vscale(float s, float4 v) {
    return make_float4(s*v.x, s*v.y, s*v.z, s*v.w);
}

// NOTE(review): the extracted source computed s/v.{x,y,z,w} here, which is
// inconsistent with the scalar overload (v/s). Made consistent: divide the
// vector components by the scalar. Confirm against upstream before relying
// on __vdiv in new code.
__device__ float4 __forceinline__ __vdiv(float s, float4 v) {
    return make_float4(v.x/s, v.y/s, v.z/s, v.w/s);
}

// Butterfly (shfl_xor) sum across the full warp; every lane gets the total.
template<typename VAL_T>
__device__ VAL_T __warp_sum(VAL_T val) {

    #pragma unroll
    for(int i = WARP_SIZE/2; i; i /= 2) {
        val += __shfl_xor_sync(FULL_MASK, val, i, WARP_SIZE);
    }
    return val;
}

// Block-wide sum for a BDIM_X x BDIM_Y x BDIM_Z block (compile-time shape);
// every thread gets the total. Uses __syncthreads, so all threads of the
// block must call it.
template<int BDIM_X,
         int BDIM_Y,
         int BDIM_Z,
         typename VAL_T>
__device__ VAL_T __block_sum(VAL_T val) {

    const int NWARP = (BDIM_X*BDIM_Y*BDIM_Z) / WARP_SIZE;

    val = __warp_sum(val);

    if constexpr(NWARP > 1) {

        int tid = threadIdx.x;
        if constexpr(BDIM_Y > 1) { tid += threadIdx.y*BDIM_X; }
        if constexpr(BDIM_Z > 1) { tid += threadIdx.z*BDIM_X*BDIM_Y; }

        const int lid = tid%WARP_SIZE;
        const int wid = tid/WARP_SIZE;

        __shared__ VAL_T sh[NWARP];

        if (lid == 0) {
            sh[wid] = val;
        }
        __syncthreads();

        if (wid == 0) {
            val = (lid < NWARP) ? sh[lid] : 0;

            val = __warp_sum(val);
            __syncwarp();

            if (!lid) {
                sh[0] = val;
            }
        }
        __syncthreads();

        val = sh[0];
        __syncthreads();
    }
    return val;
}

// transpose utils

// Tiled transpose kernel: [B, C, H, W] -> [B, H, W, C]. One BDIM_X x BDIM_X
// tile per CTA, staged through padded shared memory (+1 column avoids bank
// conflicts); the slow path handles partial edge tiles.
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0231_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                            at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int coff = blockIdx.x*BDIM_X;        // channel offset
    const int woff = blockIdx.y*BDIM_X;        // width offset
    const int batch = blockIdx.z / nlat;       // batch  (same for all block)
    const int h = blockIdx.z - (batch * nlat); // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][coff + j+tidy][h][woff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (coff + j+tidy < nchn) ? src[batch][coff + j+tidy][h][woff+tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (woff + j+tidy < nlon) {
                    dst[batch][h][woff + j+tidy][coff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}

// Host launcher for permute_to0231_k; WARPS_X_TILE warps cooperate per tile.
template<int WARPS_X_TILE,
         typename VAL_T>
void launch_permute_to0231(at::Tensor src, at::Tensor dst){
    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(1), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(2)*src.size(0);

    // grid.y/grid.z are limited to 65535 by the hardware
    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0231_k<WARP_SIZE, WARPS_X_TILE>
        <<<grid, block, 0, stream>>>(src.size(1),
                                     src.size(2),
                                     src.size(3),
                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}

// Tiled transpose kernel: [B, H, W, C] -> [B, C, H, W] (inverse of 0231).
template<int BDIM_X,
         int BDIM_Y,
         typename VAL_T>
__global__
__launch_bounds__(BDIM_X*BDIM_Y)
void permute_to0312_k(const int nchn,
                      const int nlat,
                      const int nlon,
                      const at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> src,
                            at::PackedTensorAccessor32<VAL_T, 4, at::RestrictPtrTraits> dst) {

    static_assert(!(BDIM_X & (BDIM_X-1)));
    static_assert(!(BDIM_Y & (BDIM_Y-1)));
    static_assert(BDIM_X >= BDIM_Y);

    __shared__ VAL_T sh[BDIM_X][BDIM_X+1];

    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;

    const int woff = blockIdx.x*BDIM_X;        // width offset
    const int coff = blockIdx.y*BDIM_X;        // channel offset
    const int batch = blockIdx.z / nlat;       // batch  (same for all block)
    const int h = blockIdx.z - (batch * nlat); // height (same for all block)

    const int nchn_full = (nchn-coff) >= BDIM_X;
    const int nlon_full = (nlon-woff) >= BDIM_X;

    if (nchn_full && nlon_full) {
        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            sh[j+tidy][tidx] = src[batch][h][woff + j+tidy][coff+tidx];
        }
        __syncthreads();

        #pragma unroll
        for(int j = 0; j < BDIM_X; j += BDIM_Y) {
            dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
        }
    } else {
        if (coff+tidx < nchn) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                sh[j+tidy][tidx] = (woff + j+tidy < nlon) ? src[batch][h][woff + j+tidy][coff+tidx] : VAL_T(0);
            }
        }
        __syncthreads();

        if (woff+tidx < nlon) {
            #pragma unroll
            for(int j = 0; j < BDIM_X; j += BDIM_Y) {
                if (coff + j+tidy < nchn) {
                    dst[batch][coff + j+tidy][h][woff+tidx] = sh[tidx][j+tidy];
                }
            }
        }
    }
    return;
}

// Host launcher for permute_to0312_k.
template<int WARPS_X_TILE,
         typename VAL_T>
void launch_permute_to0312(at::Tensor src, at::Tensor dst){
    dim3 block;
    dim3 grid;

    block.x = WARP_SIZE;
    block.y = WARPS_X_TILE;
    grid.x = DIV_UP(src.size(2), block.x);
    grid.y = DIV_UP(src.size(3), block.x);
    grid.z = src.size(1)*src.size(0);

    // grid.y/grid.z are limited to 65535 by the hardware
    assert(grid.y < 65536);
    assert(grid.z < 65536);

    // get stream
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    permute_to0312_k<WARP_SIZE, WARPS_X_TILE>
        <<<grid, block, 0, stream>>>(src.size(3),
                                     src.size(1),
                                     src.size(2),
                                     src.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>(),
                                     dst.packed_accessor32<VAL_T, 4, at::RestrictPtrTraits>());
}
diff --git a/torch_harmonics_attn/cudamacro.h b/torch_harmonics_attn/cudamacro.h
new file mode 100644
index 0000000000000000000000000000000000000000..0edef184557a3e82a51b471f4154d8a58f1d2343
--- /dev/null
+++ b/torch_harmonics_attn/cudamacro.h
@@ -0,0 +1,47 @@
// coding=utf-8
//
// SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. 
Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#pragma once

#include <cstdio>   // fprintf, stderr
#include <cstdlib>  // exit, EXIT_FAILURE

// Abort the process if a CUDA runtime/library call returning cudaError_t
// fails, printing file/line and the error string.
// Wrapped in do { } while(0) so the macro behaves as a single statement and
// is safe inside unbraced if/else (the original bare { } block was not).
#define CHECK_CUDA(call) do {                                                 \
        cudaError_t err = (call);                                             \
        if( cudaSuccess != err) {                                             \
            fprintf(stderr, "Cuda error in file '%s' in line %i : %s.\n",     \
                    __FILE__, __LINE__, cudaGetErrorString( err) );           \
            exit(EXIT_FAILURE);                                               \
        }} while(0)

// Abort the process if a previous kernel launch (or async CUDA operation)
// left a pending error; errorMessage identifies the launch site.
#define CHECK_ERROR(errorMessage) do {                                        \
        cudaError_t err = cudaGetLastError();                                 \
        if( cudaSuccess != err) {                                             \
            fprintf(stderr, "Cuda error: %s in file '%s' in line %i : %s.\n", \
                    errorMessage, __FILE__, __LINE__, cudaGetErrorString( err) );\
            exit(EXIT_FAILURE);                                               \
        }} while(0)