diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h new file mode 100644 index 0000000000000000000000000000000000000000..1701158b0bdd479cb179e4d0162c78ab335aba8a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm50.h @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/functional.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, double, LayoutA, double, 
LayoutB, double, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = double; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + + d[0] = a[0] * b[0] + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + 
typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + float, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0] + c[0].real(); + d[0].imag() = a[0].imag() * b[0] + c[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + float, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0] * b[0].real() + c[0].real(); + d[0].imag() = a[0] * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array, 1> const &b, + Array, 1> 
const &c + ) { + + d[0].real() = a[0].real() * b[0].real() + c[0].real(); + d[0].imag() = a[0].imag() * b[0].real() + c[0].imag(); + d[0].real() = -a[0].imag() * b[0].imag() + d[0].real(); + d[0].imag() = a[0].real() * b[0].imag() + d[0].imag(); + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + complex, + LayoutA, + double, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array, 1> const &a, + Array const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0].real() * b[0] + c[0].real(); + d[0].imag() = a[0].imag() * b[0] + c[0].imag(); + } +}; + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma< + gemm::GemmShape<1, 1, 1>, + 1, + double, + LayoutA, + complex, + LayoutB, + complex, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; + using ElementC = complex; + + CUTLASS_HOST_DEVICE + void operator()( + Array, 1> &d, + Array const &a, + Array, 1> const &b, + Array, 1> const &c + ) { + + d[0].real() = a[0] * b[0].real() + c[0].real(); + d[0].imag() = a[0] * b[0].imag() + d[0].imag(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + 
using Operator = OpMultiplyAdd; + using ElementC = float; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + d[0] = float(a[0]) * float(b[0]) + c[0]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation for Quaternions +template < + /// Layout of A matrix + typename LayoutA, + /// Layout of B matrix + typename LayoutB, + /// Layout of C matrix + typename LayoutC +> +struct Mma, 1, Quaternion, LayoutA, Quaternion, LayoutB, Quaternion, LayoutC, OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; + using Element = Quaternion; + using ElementC = Element; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + multiply_add op; + d[0] = op(a[0], b[0], c[0]); + } + +}; + +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h new file mode 100644 index 0000000000000000000000000000000000000000..31ef2b653076863cfb9387ba078d31ee8b52d607 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm60.h @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once + +#include + +#include "cutlass/arch/mma.h" + +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<2,1,1>, + 1, + half_t, + LayoutA, + half_t, + LayoutB, + half_t, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 B = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 D = __hfma2(A, B, C); + + d = reinterpret_cast &>(D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[0] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1,2,1>, + 1, + half_t, + LayoutA, + half_t, + LayoutB, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 D = __hfma2(A, B, C); + + d = reinterpret_cast 
&>(D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[0] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma < + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b)); + __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b)); + + __half2 const *C = reinterpret_cast<__half2 const *>(&c); + + __half2 Dlo = __hfma2(A, Blo, C[0]); + __half2 Dhi = __hfma2(A, Bhi, C[1]); + + Array * D = reinterpret_cast *>(&d); + + D[0] = reinterpret_cast const &>(Dlo); + D[1] = reinterpret_cast const &>(Dhi); + +#else + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 2; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j]; + } + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 Alo = 
__low2half2(reinterpret_cast<__half2 const &>(a)); + __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a)); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + + __half2 const *C = reinterpret_cast<__half2 const *>(&c); + + __half2 Dlo = __hfma2(Alo, B, C[0]); + __half2 Dhi = __hfma2(Ahi, B, C[1]); + + Array * D = reinterpret_cast *>(&d); + + D[0] = reinterpret_cast &>(Dlo); + D[1] = reinterpret_cast &>(Dhi); +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 2; ++j) { + d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j]; + } + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h new file mode 100644 index 0000000000000000000000000000000000000000..b780335efadeecee07f7c1c98422f18fec6f7ea3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm61.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1,1,4>, + 1, + int8_t, + LayoutA, + int8_t, + LayoutB, + int, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 4>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) + + unsigned const &A = reinterpret_cast(a); + unsigned const &B = reinterpret_cast(b); + + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d[0]) + : "r"(A), "r"(B), "r"(c[0])); + +#else + + d[0] = c[0]; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < 4; ++k) { + d[0] += a[k] * b[k]; + } + +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template +struct Mma< + gemm::GemmShape<1, 1, 2>, + 1, + int16_t, + layout::RowMajor, + int16_t, + layout::ColumnMajor, + int, + LayoutC, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<1, 1, 2>; + using Operator = OpMultiplyAdd; + using ElementC = int; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) + + unsigned const &A = reinterpret_cast(a); + unsigned const &B = reinterpret_cast(b); + + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d[0]) + : "r"(A), "r"(B), "r"(c[0])); +#else + d[0] = c[0]; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < 2; ++k) { + d[0] += a[k] * b[k]; + } +#endif + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..6acdcfac3b9d3d10253d3a343a1d097b617ddb16 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm70.h @@ -0,0 +1,661 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +#define CUTLASS_ARCH_MMA_SM70_SUPPORTED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) +#define CUTLASS_ARCH_MMA_SM70_ENABLED +#endif + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Matrix multiply accumulate 884 - FP16 accumulation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8,8,4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + 
using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm 
volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA 
= half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Matrix multiply accumulate 884 - FP32 accumulation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + 
FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::ColumnMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + 
"=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 4>, + 
8, + half_t, + layout::RowMajor, + half_t, + layout::RowMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8, 8, 4>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::RowMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm70; + + /// Multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) { + +#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), + "=f"(D[1]), + "=f"(D[2]), + "=f"(D[3]), + "=f"(D[4]), + "=f"(D[5]), + "=f"(D[6]), + "=f"(D[7]) + : "r"(A[0]), + "r"(A[1]), + "r"(B[0]), + "r"(B[1]), + "f"(C[0]), + "f"(C[1]), + "f"(C[2]), + "f"(C[3]), + "f"(C[4]), + "f"(C[5]), + "f"(C[6]), + "f"(C[7]) + ); + +#else + assert(0); + #if defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); + #endif +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation specialized for the entire warp +template < + typename LayoutA, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename Operator +> +struct Mma< + gemm::GemmShape<16, 16, 4>, + 32, + half_t, + LayoutA, + half_t, + LayoutB, + ElementC, + LayoutC, + Operator +> : + public Mma< + gemm::GemmShape<8, 8, 4>, + 8, + half_t, + LayoutA, + half_t, + LayoutB, + ElementC, + LayoutC, + Operator> { + + using Shape = gemm::GemmShape<16, 16, 4>; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..c71ea076b5c2390cea8b0ba17ae1b642c5d49b48 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm75.h @@ -0,0 +1,789 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply for SM75 +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "cutlass/arch/wmma.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply. +#include +#include "cutlass/wmma_array.h" +#endif + +// CUTLASS includes +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) + +#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +#define CUTLASS_ARCH_MMA_SM75_ENABLED +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - FP16 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + 
layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + unsigned const *C = reinterpret_cast(&c); + unsigned *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + 
CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const *A = reinterpret_cast(&a); + unsigned const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply (8b) with SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm 
volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = 
Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm 
volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Integer matrix multiply (4b) - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + 
int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + int4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if 
defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct Mma< + gemm::GemmShape<8, 8, 32>, + 32, + uint4b_t, + layout::RowMajor, + uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<8, 8, 32>; + + using ElementA = uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + + unsigned const & A = reinterpret_cast(a); + unsigned const & B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// b1 ^ b1 + s32 => s32 +// 
+//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template <> +struct Mma< + gemm::GemmShape<8,8,128>, + 32, + uint1b_t, + layout::RowMajor, + uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpXorPopc> { + + using Shape = gemm::GemmShape<8,8,128>; + + using ElementA = uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpXorPopc; + using ArchTag = arch::Sm75; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + using WmmaFragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + nvcuda::wmma::experimental::precision::b1, + nvcuda::wmma::row_major>; + + using WmmaFragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + nvcuda::wmma::experimental::precision::b1, + nvcuda::wmma::col_major>; + + using WmmaFragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + int>; + + WmmaFragmentA const & A = reinterpret_cast(a); + WmmaFragmentB const & B = reinterpret_cast(b); + + WmmaFragmentC const & C = reinterpret_cast(c); + WmmaFragmentC & D = reinterpret_cast(d); + + nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, + nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. 
+ +#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..22cd87d65b0412e9ac9a4953feee022c5e5feb92 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm80.h @@ -0,0 +1,1500 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + +#define CUTLASS_ARCH_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED +#endif +#if (__CUDA_ARCH__ <= 890) +#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 8>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = 
gemm::GemmShape<16, 8, 8>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1684 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 4>, + 32, + tfloat32_t, + layout::RowMajor, + tfloat32_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 4>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void 
operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), + "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 1688 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct Mma, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd> { + using Shape = gemm::GemmShape<16, 8, 8>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, 
{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + 
+/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + bfloat16_t, + layout::RowMajor, + bfloat16_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct Mma< + gemm::GemmShape<16, 8, 16>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using 
FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 884 - F64 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<8,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<8,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm80; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + 
double const & A = reinterpret_cast(a); + double const & B = reinterpret_cast(b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 {%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=d"(D[0]), "=d"(D[1]) + : "d"(A), "d"(B), "d"(C[0]), "d"(C[1])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + 
assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + 
CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const &B = reinterpret_cast(b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), + "r"(C[3])); + +#else + assert(0); +#endif + } +}; + 
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: S32 = S8 * S8 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 32>, 32, int8_t, layout::RowMajor, int8_t,
+           layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 32>;
+
+  using ElementA = int8_t;
+  using LayoutA = layout::RowMajor;
+  // 16 x s8 per thread = four 32-bit registers (%4-%7 below).
+  using FragmentA = Array<int8_t, 16>;
+
+  using ElementB = int8_t;
+  using LayoutB = layout::ColumnMajor;
+  // 8 x s8 per thread = two 32-bit registers (%8-%9).
+  using FragmentB = Array<int8_t, 8>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = U8 * S8 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 32>, 32, uint8_t, layout::RowMajor, int8_t,
+           layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 32>;
+
+  using ElementA = uint8_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<uint8_t, 16>;
+
+  using ElementB = int8_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<int8_t, 8>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = S8 * U8 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 32>, 32, int8_t, layout::RowMajor, uint8_t,
+           layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 32>;
+
+  using ElementA = int8_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<int8_t, 16>;
+
+  using ElementB = uint8_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<uint8_t, 8>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = U8 * U8 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 32>, 32, uint8_t, layout::RowMajor, uint8_t,
+           layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 32>;
+
+  using ElementA = uint8_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<uint8_t, 16>;
+
+  using ElementB = uint8_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<uint8_t, 8>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE
+//
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: S32 = S4 * S4 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 64>, 32, cutlass::int4b_t, layout::RowMajor,
+           cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor,
+           OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 64>;
+
+  using ElementA = cutlass::int4b_t;
+  using LayoutA = layout::RowMajor;
+  // 32 x s4 per thread = four 32-bit registers (%4-%7 below).
+  using FragmentA = Array<cutlass::int4b_t, 32>;
+
+  using ElementB = cutlass::int4b_t;
+  using LayoutB = layout::ColumnMajor;
+  // 16 x s4 per thread = two 32-bit registers (%8-%9).
+  using FragmentB = Array<cutlass::int4b_t, 16>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    CUTLASS_UNUSED(a);
+    CUTLASS_UNUSED(b);
+    CUTLASS_UNUSED(c);
+    CUTLASS_UNUSED(d);
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = U4 * S4 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 64>, 32, cutlass::uint4b_t, layout::RowMajor,
+           cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor,
+           OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 64>;
+
+  using ElementA = cutlass::uint4b_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<cutlass::uint4b_t, 32>;
+
+  using ElementB = cutlass::int4b_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<cutlass::int4b_t, 16>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    CUTLASS_UNUSED(a);
+    CUTLASS_UNUSED(b);
+    CUTLASS_UNUSED(c);
+    CUTLASS_UNUSED(d);
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = S4 * U4 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 64>, 32, cutlass::int4b_t, layout::RowMajor,
+           cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor,
+           OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 64>;
+
+  using ElementA = cutlass::int4b_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<cutlass::int4b_t, 32>;
+
+  using ElementB = cutlass::uint4b_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<cutlass::uint4b_t, 16>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    CUTLASS_UNUSED(a);
+    CUTLASS_UNUSED(b);
+    CUTLASS_UNUSED(c);
+    CUTLASS_UNUSED(d);
+    assert(0);
+#endif
+  }
+};
+
+/// Matrix multiply-add operation: S32 = U4 * U4 + S32
+template <>
+struct Mma<gemm::GemmShape<16, 8, 64>, 32, cutlass::uint4b_t, layout::RowMajor,
+           cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor,
+           OpMultiplyAddSaturate> {
+  using Shape = gemm::GemmShape<16, 8, 64>;
+
+  using ElementA = cutlass::uint4b_t;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<cutlass::uint4b_t, 32>;
+
+  using ElementB = cutlass::uint4b_t;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<cutlass::uint4b_t, 16>;
+
+  using ElementC = int;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<int, 4>;
+
+  using Operator = OpMultiplyAddSaturate;
+  using ArchTag = arch::Sm80;
+
+  /// Computes multiply-add with saturating accumulation (satfinite)
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
+
+    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
+    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
+
+    int const *C = reinterpret_cast<int const *>(&c);
+    int *D = reinterpret_cast<int *>(&d);
+
+    asm volatile(
+        "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, "
+        "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
+        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]),
+          "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
+
+#else
+    CUTLASS_UNUSED(a);
+    CUTLASS_UNUSED(b);
+    CUTLASS_UNUSED(c);
+    CUTLASS_UNUSED(d);
+    assert(0);
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Matrix 
Multiply 168256 - B1 input, S32 accumulation - AND,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int32_t, + layout::RowMajor, + OpAndPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpAndPopc; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + 
using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 168256 - B1 input, S32 accumulation - XOR,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpXorPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpXorPopc; + using ArchTag = 
arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); + +#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h new file mode 100644 index 0000000000000000000000000000000000000000..4bcd9bc1de9b6e53629e08f478a50d791d198a1a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm89.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Matrix multiply-accumulate specialzied for SM89 +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +# define CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED +#endif + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) +# define CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) +# define CUTLASS_ARCH_MMA_F32_SM89_ENABLED +# endif + +# if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) +# define CUTLASS_ARCH_MMA_F16_SM89_ENABLED +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Whether the Mma uses as SM89 staged accumulation policy +template +static constexpr bool is_sm89_staged_policy_v = + ( + // ElementA must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // ElementB must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // The instruction shape must be 16x8x32 + Operator::ArchMmaOperator::Shape::kM == 16 && + Operator::ArchMmaOperator::Shape::kN == 8 && + Operator::ArchMmaOperator::Shape::kK == 32 + ) && + ( + // The operator must be OpMultiplyAdd (default) + platform::is_same::value + ); +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation +// 
+//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + 
platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using 
FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = 
reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP16 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F16 = fe4m3 * fe4m3 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : 
"=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe4m3 * fe5m2 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe5m2 * fe4m3 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + 
cutlass::float_e4m3_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = 
layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h new file mode 100644 index 0000000000000000000000000000000000000000..b135c8645b48eb40a1cce88c515074e95d4b6a5e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sm90.h @@ -0,0 +1,241 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/config.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x4 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), + "d"(B[0]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x8 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,8>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,8>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=d"(D[0]), "=d"(d[1]), "=d"(d[2]), "=d"(d[3]) + : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), + "d"(B[0]), "d"(B[1]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x16 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,16>, + 32, + double, + layout::RowMajor, + double, + 
layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,16>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[2]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]), + "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..e4ca91a10293334fbd89e21891132442a6216e6a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm80.h @@ -0,0 +1,1234 
@@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Sparse matrix multiply accumulate for SM80 +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) + +#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED +#endif + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread +> { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const 
&c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#endif + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + 
F32 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread + > { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), 
"f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } + +#endif + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct SparseMma, 32, bfloat16_t, layout::RowMajor, + bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const 
kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), 
"r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16816 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct SparseMma, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 4; + + static int const kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), 
"r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = 
gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + 
} +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : 
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), 
"r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + 
"mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + 
static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + 
using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + 
cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + 
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAddSaturate; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); 
+ } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif + +#else + + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h new file mode 100644 index 0000000000000000000000000000000000000000..6adca25527efdc1c3cb564b4553d96bebe59b3fd --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/mma_sparse_sm89.h @@ -0,0 +1,406 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Sparse matrix multiply accumulate for SM89 +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED) +# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED +# endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + 
FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + 
FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void 
operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + 
CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h new file mode 100644 index 0000000000000000000000000000000000000000..93dd37d3193867602d69866cb2cfcd2e27e87f62 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/reg_reconfig.h @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief PTX for CTA Reconfiguration +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#ifndef CUDA_CTA_RECONFIG_ACTIVATED + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ + (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \ + || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \ + || (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \ + || (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \ + || (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \ + || (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \ + ) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif + + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ + (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ + || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ + || (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \ + || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ + || (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \ + ) + #define CUDA_CTA_RECONFIG_ACTIVATED 1 + #endif + +#endif + +namespace cutlass { +namespace arch { + +template +CUTLASS_DEVICE +void warpgroup_reg_alloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} + +template +CUTLASS_DEVICE +void warpgroup_reg_dealloc(){ +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); +#endif +} + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h new file mode 100644 index 0000000000000000000000000000000000000000..a1dc7dff4d603ecf7e6a190c84bc7634e8c8be62 --- /dev/null +++ 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates exposing SIMD operators +*/ + +#pragma once + +#include "cutlass/arch/array.h" +#include "cutlass/arch/numeric_types.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Element-wise operators +// + +CUTLASS_HOST_DEVICE +template +Array operator*(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] * b[i]; + } + return d; +} + +CUTLASS_HOST_DEVICE +template +Array operator+(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] + b[i]; + } + return d; +} + +CUTLASS_HOST_DEVICE +template +Array operator-(Array const &a, Array const &b) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] - b[i]; + } + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Multiply-accumulate operators +// + +CUTLASS_HOST_DEVICE +template +Array mac(Array const &a, Array const &b, Array const &c) { + Array d; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Dot product operator +// + +CUTLASS_HOST_DEVICE +template +Accumulator dot(Array const &a, Array const &b, Accumulator accum) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + accum += a[i] * b[i]; + } + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "simd_sm60.h" +#include "simd_sm61.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// 
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h new file mode 100644 index 0000000000000000000000000000000000000000..59f38d62da91ab9af6a1f73a5990d29056dd259a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm60.h @@ -0,0 +1,104 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing SIMD operators for SM60 +*/ + +#pragma once + +#include "simd.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Element-wise operators - specialized for half_t x 2 +// + +CUTLASS_HOST_DEVICE +template <> +Array operator*(Array const &a, Array const &b) { + Array d; + + return d; +} + +CUTLASS_HOST_DEVICE +template <> +Array operator+(AArray const &a, Array const &b) { + Array d; + + return d; +} + +CUTLASS_HOST_DEVICE +template <> +Array operator-(Array const &a, Array const &b) { + Array d; + + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiply-accumulate operators - specialized for half_t x 2 +CUTLASS_HOST_DEVICE +template <> +Array mac(Array const &a, Array const &b, Array const &c) { + Array d; + + return d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for half_t <- (half_t * half_t) x 2 + half_t +CUTLASS_HOST_DEVICE +template <> +half_t dot(Array const &a, Array const &b, half_t accum) { + + return accum; +} + +/// Dot product operator - specialized for float <- (half_t * half_t) x 2 + float 
+CUTLASS_HOST_DEVICE +template <> +float dot(Array const &a, Array const &b, float accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h new file mode 100644 index 0000000000000000000000000000000000000000..46c22665c2126b5dd2e0fb143be00143b933f3ec --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/simd_sm61.h @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing SIMD operators for SM61 +*/ + +#pragma once + +#include "simd.h" + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int8_t * int8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint8_t * int8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int8_t * uint8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint8_t * uint8_t) x 4 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t 
accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int8_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (int16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +/// Dot product operator - specialized for int32_t <- (uint16_t * int16_t) x 2 + int32_t +CUTLASS_HOST_DEVICE +template <> +int32_t dot(Array const &a, Array const &b, int32_t accum) { + + return accum; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..5567fe561f8ca7a95f7b0958aaced2696109f22a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/synclog.hpp @@ -0,0 +1,1271 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Synchronization event logging for race condition debugging. +*/ + +#pragma once + +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#endif + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) + +constexpr uint32_t synclog_cap = 1 << 26; + +inline std::mutex synclog_mutex; +inline std::vector synclog_buf_list; +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +CUTLASS_DEVICE uint32_t* synclog_buf; +#endif + +CUTLASS_DEVICE +uint32_t* synclog_alloc(uint32_t n) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t* buf = synclog_buf; + if (buf == nullptr) return nullptr; + uint32_t last = atomicAdd(&buf[0], n); + if (last + n < synclog_cap) return buf + last + 1; + if (last >= synclog_cap) atomicAdd(&buf[0], -n); + #endif + return nullptr; +} + +CUTLASS_DEVICE +void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint64_t time64; + asm volatile ( + "mov.u64 %0, %%globaltimer;\n" + : "=l"(time64) : + ); + to[0] = header; + to[1] = line; + to[2] = time64; + to[3] = time64 >> 32; + to[4] = threadIdx.x; + to[5] = threadIdx.y; + to[6] = threadIdx.z; + to[7] = blockIdx.x; + to[8] = blockIdx.y; + to[9] = blockIdx.z; + #endif +} + +constexpr uint32_t synclog_header_none = 0; +constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3; + +constexpr bool synclog_enable_syncthreads = true; +constexpr uint32_t synclog_header_syncthreads = 1; +constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0; + 
+constexpr bool synclog_enable_syncwarp = true; +constexpr uint32_t synclog_header_syncwarp = 2; +constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0; + +constexpr bool synclog_enable_named_barrier_arrive_and_wait = true; +constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3; +constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2; + +constexpr bool synclog_enable_named_barrier_arrive = true; +constexpr uint32_t synclog_header_named_barrier_arrive = 4; +constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_init = true; +constexpr uint32_t synclog_header_cluster_barrier_init = 5; +constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2; + +constexpr bool synclog_enable_cluster_barrier_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_wait = 6; +constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2; +constexpr bool synclog_enable_cluster_barrier_test_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; +constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3; +constexpr bool synclog_enable_cluster_barrier_try_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2; +constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3; +constexpr bool synclog_enable_cluster_barrier_arrive = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1; +constexpr bool synclog_enable_cluster_barrier_invalidate = true; +constexpr 
uint32_t synclog_header_cluster_barrier_invalidate = 11; +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1; +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2; +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4; +constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2; +constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4; +constexpr bool synclog_enable_fence_barrier_init = true; +constexpr uint32_t synclog_header_fence_barrier_init = 16; +constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; + +constexpr bool synclog_enable_fence_view_async_shared = true; +constexpr uint32_t synclog_header_fence_view_async_shared = 17; +constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_wait = true; +constexpr uint32_t synclog_header_cp_async_wait = 18; +constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_cp_async_wait_all = true; 
+constexpr uint32_t synclog_header_cp_async_wait_all = 19; +constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_fence = true; +constexpr uint32_t synclog_header_cp_async_fence = 20; +constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_nan = true; +constexpr uint32_t synclog_header_cp_async_nan = 21; +constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cp_async_zfill = true; +constexpr uint32_t synclog_header_cp_async_zfill = 22; +constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cp_async = true; +constexpr uint32_t synclog_header_cp_async = 23; +constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; + +constexpr bool synclog_enable_tma_load = true; +constexpr uint32_t synclog_header_tma_load = 24; +constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; + +constexpr bool synclog_enable_tma_store = true; +constexpr uint32_t synclog_header_tma_store = 25; +constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; + +constexpr bool synclog_enable_tma_store_arrive = true; +constexpr uint32_t synclog_header_tma_store_arrive = 26; +constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_tma_store_wait = true; +constexpr uint32_t synclog_header_tma_store_wait = 27; +constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_arrive = true; +constexpr uint32_t synclog_header_warpgroup_arrive = 28; +constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_warpgroup_wait = true; +constexpr uint32_t synclog_header_warpgroup_wait = 29; +constexpr uint32_t synclog_length_warpgroup_wait = 
synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_commit_batch = true; +constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; +constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; + +constexpr bool synclog_enable_wgmma_reg_smem = true; +constexpr uint32_t synclog_header_wgmma_reg_smem = 31; +constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; + +constexpr bool synclog_enable_wgmma_smem_smem = true; +constexpr uint32_t synclog_header_wgmma_smem_smem = 32; +constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cpasync_barrier_arrive = true; +constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1; +CUTLASS_DEVICE +bool synclog_condition_emit() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return 0; + #endif +} + +CUTLASS_DEVICE +bool synclog_condition_print() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return false; + #endif +} + +CUTLASS_DEVICE +void synclog_print_prefix(char const* header, uint32_t at) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t line = synclog_buf[at + 1]; + uint32_t timeLo = synclog_buf[at + 2]; + uint32_t timeHi = synclog_buf[at + 3]; + uint32_t threadIdxX = synclog_buf[at + 4]; + uint32_t threadIdxY = synclog_buf[at + 5]; + uint32_t threadIdxZ = synclog_buf[at + 6]; + uint32_t blockIdxX = synclog_buf[at + 7]; + uint32_t blockIdxY = synclog_buf[at + 8]; + uint32_t blockIdxZ = synclog_buf[at + 9]; + printf( + "%s line=%u time=%lu 
thread=%u,%u,%u block=%u,%u,%u ", + header, line, + (uint64_t)timeHi << 32 | timeLo, + threadIdxX, threadIdxY, threadIdxZ, + blockIdxX, blockIdxY, blockIdxZ + ); + #endif +} + +CUTLASS_DEVICE +void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { + CUTLASS_UNUSED(hi); + uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; + printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); +} + +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void synclog_setup() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + std::scoped_lock lock(synclog_mutex); + auto fail = [] () { + fprintf(stderr, "synclog_setup() failed\n"); + std::terminate(); + }; + int orig_device = 0; + if (cudaGetDevice(&orig_device) != cudaSuccess) { + fail(); + } + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + fail(); + } + if (synclog_buf_list.size() == 0) { + for (int device = 0; device < device_count; device++) { + uint32_t* buf = 0; + if (cudaSetDevice(device) != cudaSuccess || + cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { + fail(); + } + synclog_buf_list.push_back(buf); + } + } + for (int device = 0; device < device_count; device++) { + uint32_t* buf = synclog_buf_list.at(device); + if (cudaSetDevice(device) != cudaSuccess || + cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || + cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { + fail(); + } + } + if (cudaSetDevice(orig_device) != cudaSuccess) { + fail(); + } + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncthreads(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncthreads) return; + if (!synclog_condition_emit()) return; + uint32_t* to = 
synclog_alloc(synclog_length_syncthreads); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncthreads, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncwarp(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncwarp) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncwarp); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncwarp, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive_and_wait( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + 
// Cluster-barrier / cluster-transaction-barrier event emitters.
// All share one pattern: compile-time enable gate, run-time emit predicate,
// claim space with synclog_alloc, write prefix, then positional payload words.
// The payload order must match the positional decode in synclog_print().

// Payload: {smem_addr, arrive_count}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_init(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t arrive_count) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_init) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = arrive_count;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(arrive_count);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, phase}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = phase;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, phase, pred}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_test_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_test_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = phase;
  to[synclog_length_prefix + 2] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, phase}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_try_wait(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t phase) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_try_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = phase;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(phase);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, cta_id, pred}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_arrive_cluster(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t cta_id,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = cta_id;
  to[synclog_length_prefix + 2] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(cta_id);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_arrive(
  uint32_t line,
  uint32_t smem_addr) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line);
  to[synclog_length_prefix + 0] = smem_addr;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr}.
CUTLASS_DEVICE
void synclog_emit_cluster_barrier_invalidate(
  uint32_t line,
  uint32_t smem_addr) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_barrier_invalidate) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line);
  to[synclog_length_prefix + 0] = smem_addr;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, transaction_bytes}.
CUTLASS_DEVICE
void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = transaction_bytes;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, transaction_bytes, cta_id, pred}.
CUTLASS_DEVICE
void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes,
  uint32_t cta_id,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = transaction_bytes;
  to[synclog_length_prefix + 2] = cta_id;
  to[synclog_length_prefix + 3] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
  CUTLASS_UNUSED(cta_id);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, transaction_bytes}.
CUTLASS_DEVICE
void synclog_emit_cluster_transaction_barrier_expect_transaction(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t transaction_bytes) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = transaction_bytes;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(transaction_bytes);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, dst_cta_id, transaction_bytes, pred}.
CUTLASS_DEVICE
void synclog_emit_cluster_transaction_barrier_complete_transaction(
  uint32_t line,
  uint32_t smem_addr,
  uint32_t dst_cta_id,
  uint32_t transaction_bytes,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = dst_cta_id;
  to[synclog_length_prefix + 2] = transaction_bytes;
  to[synclog_length_prefix + 3] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(dst_cta_id);
  CUTLASS_UNUSED(transaction_bytes);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event (header + source line, no payload).
CUTLASS_DEVICE
void synclog_emit_fence_barrier_init(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_fence_barrier_init) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_fence_barrier_init, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_fence_view_async_shared(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_fence_view_async_shared) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {n} (the cp.async wait-group depth argument).
CUTLASS_DEVICE
void synclog_emit_cp_async_wait(
  uint32_t line,
  uint32_t n) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async_wait, line);
  to[synclog_length_prefix + 0] = n;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(n);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_cp_async_wait_all(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_wait_all) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_cp_async_fence(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_fence) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async_fence);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async_fence, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, gmem_ptr lo32, gmem_ptr hi32, pred}.
// The 64-bit global pointer is split into two 32-bit log words.
CUTLASS_DEVICE
void synclog_emit_cp_async_nan(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_nan) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async_nan);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async_nan, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
  to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
  to[synclog_length_prefix + 3] = pred;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, gmem_ptr lo32, gmem_ptr hi32, pred, size}.
CUTLASS_DEVICE
void synclog_emit_cp_async_zfill(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred,
  uint32_t size) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async_zfill) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async_zfill, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
  to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
  to[synclog_length_prefix + 3] = pred;
  to[synclog_length_prefix + 4] = size;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
  CUTLASS_UNUSED(size);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr, gmem_ptr lo32, gmem_ptr hi32, pred, size}.
CUTLASS_DEVICE
void synclog_emit_cp_async(
  uint32_t line,
  uint32_t smem_addr,
  const void* gmem_ptr,
  uint32_t pred,
  uint32_t size) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cp_async) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cp_async);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cp_async, line);
  to[synclog_length_prefix + 0] = smem_addr;
  to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr);
  to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32);
  to[synclog_length_prefix + 3] = pred;
  to[synclog_length_prefix + 4] = size;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  CUTLASS_UNUSED(gmem_ptr);
  CUTLASS_UNUSED(pred);
  CUTLASS_UNUSED(size);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {gmem_int_desc lo32, gmem_int_desc hi32, smem_int_mbar, smem_int_ptr}.
CUTLASS_DEVICE
void synclog_emit_tma_load(
  uint32_t line,
  uint64_t gmem_int_desc,
  uint32_t smem_int_mbar,
  uint32_t smem_int_ptr) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_load) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_tma_load);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_tma_load, line);
  to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc);
  to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
  to[synclog_length_prefix + 2] = smem_int_mbar;
  to[synclog_length_prefix + 3] = smem_int_ptr;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(gmem_int_desc);
  CUTLASS_UNUSED(smem_int_mbar);
  CUTLASS_UNUSED(smem_int_ptr);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {gmem_int_desc lo32, gmem_int_desc hi32, smem_int_ptr}.
CUTLASS_DEVICE
void synclog_emit_tma_store(
  uint32_t line,
  uint64_t gmem_int_desc,
  uint32_t smem_int_ptr) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_tma_store);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_tma_store, line);
  to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc);
  to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32);
  to[synclog_length_prefix + 2] = smem_int_ptr;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(gmem_int_desc);
  CUTLASS_UNUSED(smem_int_ptr);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_tma_store_arrive(uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_tma_store_arrive, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {count}.
CUTLASS_DEVICE
void synclog_emit_tma_store_wait(
  uint32_t line,
  uint32_t count) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_tma_store_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_tma_store_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_tma_store_wait, line);
  to[synclog_length_prefix + 0] = count;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(count);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_warpgroup_arrive(
  uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {n}.
CUTLASS_DEVICE
void synclog_emit_warpgroup_wait(
  uint32_t line,
  uint32_t n) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_wait) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_warpgroup_wait, line);
  to[synclog_length_prefix + 0] = n;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(n);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Prefix-only event.
CUTLASS_DEVICE
void synclog_emit_warpgroup_commit_batch(
  uint32_t line) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_warpgroup_commit_batch) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line);
  #else
  CUTLASS_UNUSED(line);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {desc_b lo32, desc_b hi32}. The assignments narrow the 64-bit
// descriptor into 32-bit log words (low word first, then shifted high word).
CUTLASS_DEVICE
void synclog_emit_wgmma_reg_smem(
  uint32_t line,
  uint64_t desc_b) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_wgmma_reg_smem) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line);
  to[synclog_length_prefix + 0] = desc_b;
  to[synclog_length_prefix + 1] = desc_b >> 32;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(desc_b);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {desc_a lo32, desc_a hi32, desc_b lo32, desc_b hi32}.
CUTLASS_DEVICE
void synclog_emit_wgmma_smem_smem(
  uint32_t line,
  uint64_t desc_a,
  uint64_t desc_b) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_wgmma_smem_smem) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line);
  to[synclog_length_prefix + 0] = desc_a;
  to[synclog_length_prefix + 1] = desc_a >> 32;
  to[synclog_length_prefix + 2] = desc_b;
  to[synclog_length_prefix + 3] = desc_b >> 32;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(desc_a);
  CUTLASS_UNUSED(desc_b);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Payload: {smem_addr}.
CUTLASS_DEVICE
void synclog_emit_cpasync_barrier_arrive(
  uint32_t line,
  uint32_t smem_addr) {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  if constexpr (!synclog_enable_cpasync_barrier_arrive) return;
  if (!synclog_condition_emit()) return;
  uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive);
  if (to == nullptr) return;
  synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line);
  to[synclog_length_prefix + 0] = smem_addr;
  #else
  CUTLASS_UNUSED(line);
  CUTLASS_UNUSED(smem_addr);
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

// Decodes and prints the accumulated log buffer. When synclog is disabled this
// collapses to an empty device function; when enabled under a CUDA compiler it
// is kept out-of-line (noinline) since it is debug-only cold code.
#if !defined(CUTLASS_ENABLE_SYNCLOG)
CUTLASS_DEVICE
#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
static __attribute__((__noinline__)) __device__
#else
static __attribute__((__noinline__))
#endif
void synclog_print() {
  #if defined(CUTLASS_ENABLE_SYNCLOG)
  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
  if (synclog_buf == nullptr || !synclog_condition_print()) {
    return;
  }
  printf("synclog start\n");
  // Walk the buffer record by record; slot 0 holds the write cursor, so the
  // first record starts at index 1. Each branch below must advance `at` by the
  // record length for its event type before decoding payload words backwards
  // from the new position.
  for (uint32_t at = 1; at < synclog_cap; ) {
    uint32_t header = synclog_buf[at];
    if (header == synclog_header_none) {
      break;
    }
    printf("synclog at %u: ", at);
    if constexpr (synclog_enable_syncthreads) {
      if (header == synclog_header_syncthreads) {
        synclog_print_prefix("syncthreads", at);
        at += synclog_length_syncthreads;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_syncwarp) {
      if (header == synclog_header_syncwarp) {
        // Continuation of the synclog_print decode loop: one branch per event
        // type, mirroring the payload layout written by the matching
        // synclog_emit_* function (payload read back-to-front via at-K).
        synclog_print_prefix("syncwarp", at);
        at += synclog_length_syncwarp;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_named_barrier_arrive_and_wait) {
      if (header == synclog_header_named_barrier_arrive_and_wait) {
        synclog_print_prefix("named_barrier_arrive_and_wait", at);
        at += synclog_length_named_barrier_arrive_and_wait;
        printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_named_barrier_arrive) {
      if (header == synclog_header_named_barrier_arrive) {
        synclog_print_prefix("named_barrier_arrive", at);
        at += synclog_length_named_barrier_arrive;
        printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_init) {
      if (header == synclog_header_cluster_barrier_init) {
        synclog_print_prefix("cluster_barrier_init", at);
        at += synclog_length_cluster_barrier_init;
        printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_wait) {
      if (header == synclog_header_cluster_barrier_wait) {
        synclog_print_prefix("cluster_barrier_wait", at);
        at += synclog_length_cluster_barrier_wait;
        printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_test_wait) {
      if (header == synclog_header_cluster_barrier_test_wait) {
        synclog_print_prefix("cluster_barrier_test_wait", at);
        at += synclog_length_cluster_barrier_test_wait;
        printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_try_wait) {
      if (header == synclog_header_cluster_barrier_try_wait) {
        synclog_print_prefix("cluster_barrier_try_wait", at);
        at += synclog_length_cluster_barrier_try_wait;
        printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_arrive_cluster) {
      if (header == synclog_header_cluster_barrier_arrive_cluster) {
        synclog_print_prefix("cluster_barrier_arrive_cluster", at);
        at += synclog_length_cluster_barrier_arrive_cluster;
        printf("smem_addr=%u cta_id=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_arrive) {
      if (header == synclog_header_cluster_barrier_arrive) {
        synclog_print_prefix("cluster_barrier_arrive", at);
        at += synclog_length_cluster_barrier_arrive;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_barrier_invalidate) {
      if (header == synclog_header_cluster_barrier_invalidate) {
        synclog_print_prefix("cluster_barrier_invalidate", at);
        at += synclog_length_cluster_barrier_invalidate;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) {
      if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) {
        synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at);
        at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx;
        printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
      if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) {
        synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at);
        at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster;
        printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) {
      if (header == synclog_header_cluster_transaction_barrier_expect_transaction) {
        synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at);
        at += synclog_length_cluster_transaction_barrier_expect_transaction;
        printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) {
      if (header == synclog_header_cluster_transaction_barrier_complete_transaction) {
        synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at);
        at += synclog_length_cluster_transaction_barrier_complete_transaction;
        printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_fence_barrier_init) {
      if (header == synclog_header_fence_barrier_init) {
        synclog_print_prefix("fence_barrier_init", at);
        at += synclog_length_fence_barrier_init;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_fence_view_async_shared) {
      if (header == synclog_header_fence_view_async_shared) {
        synclog_print_prefix("fence_view_async_shared", at);
        at += synclog_length_fence_view_async_shared;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_wait) {
      if (header == synclog_header_cp_async_wait) {
        synclog_print_prefix("cp_async_wait", at);
        at += synclog_length_cp_async_wait;
        printf("n=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_wait_all) {
      if (header == synclog_header_cp_async_wait_all) {
        synclog_print_prefix("cp_async_wait_all", at);
        at += synclog_length_cp_async_wait_all;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_fence) {
      if (header == synclog_header_cp_async_fence) {
        synclog_print_prefix("cp_async_fence", at);
        at += synclog_length_cp_async_fence;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_nan) {
      if (header == synclog_header_cp_async_nan) {
        synclog_print_prefix("cp_async_nan", at);
        at += synclog_length_cp_async_nan;
        // Reassemble the 64-bit global address from its two 32-bit log words.
        uint64_t gmem_addr = synclog_buf[at-3];
        gmem_addr += (uint64_t)synclog_buf[at-2] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async_zfill) {
      if (header == synclog_header_cp_async_zfill) {
        synclog_print_prefix("cp_async_zfill", at);
        at += synclog_length_cp_async_zfill;
        uint64_t gmem_addr = synclog_buf[at-4];
        gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_cp_async) {
      if (header == synclog_header_cp_async) {
        synclog_print_prefix("cp_async", at);
        at += synclog_length_cp_async;
        uint64_t gmem_addr = synclog_buf[at-4];
        gmem_addr += (uint64_t)synclog_buf[at-3] << 32;
        printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_load) {
      if (header == synclog_header_tma_load) {
        synclog_print_prefix("tma_load", at);
        at += synclog_length_tma_load;
        uint64_t gmem_int_desc = synclog_buf[at-4];
        gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32;
        printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store) {
      if (header == synclog_header_tma_store) {
        synclog_print_prefix("tma_store", at);
        at += synclog_length_tma_store;
        uint64_t gmem_int_desc = synclog_buf[at-3];
        gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32;
        printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store_arrive) {
      if (header == synclog_header_tma_store_arrive) {
        synclog_print_prefix("tma_store_arrive", at);
        at += synclog_length_tma_store_arrive;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_tma_store_wait) {
      if (header == synclog_header_tma_store_wait) {
        synclog_print_prefix("tma_store_wait", at);
        at += synclog_length_tma_store_wait;
        printf("count=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_arrive) {
      if (header == synclog_header_warpgroup_arrive) {
        synclog_print_prefix("warpgroup_arrive", at);
        at += synclog_length_warpgroup_arrive;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_wait) {
      if (header == synclog_header_warpgroup_wait) {
        synclog_print_prefix("warpgroup_wait", at);
        at += synclog_length_warpgroup_wait;
        printf("n=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    if constexpr (synclog_enable_warpgroup_commit_batch) {
      if (header == synclog_header_warpgroup_commit_batch) {
        synclog_print_prefix("warpgroup_commit_batch", at);
        at += synclog_length_warpgroup_commit_batch;
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_wgmma_reg_smem) {
      if (header == synclog_header_wgmma_reg_smem) {
        synclog_print_prefix("wgmma_reg_smem", at);
        at += synclog_length_wgmma_reg_smem;
        synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_wgmma_smem_smem) {
      if (header == synclog_header_wgmma_smem_smem) {
        synclog_print_prefix("wgmma_smem_smem", at);
        at += synclog_length_wgmma_smem_smem;
        synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " ");
        synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], "");
        printf("\n");
        continue;
      }
    }
    if constexpr (synclog_enable_cpasync_barrier_arrive) {
      if (header == synclog_header_cpasync_barrier_arrive) {
        synclog_print_prefix("cpasync_barrier_arrive", at);
        at += synclog_length_cpasync_barrier_arrive;
        printf("smem_addr=%u\n", synclog_buf[at-1]);
        continue;
      }
    }
    // Unknown header word: the log is corrupt; trap into the debugger.
    asm volatile ("brkpt;\n" ::);
  }
  // Slot 0 is the allocation cursor; exceeding capacity means later events
  // were dropped rather than written.
  if (synclog_buf[0] >= synclog_cap) {
    printf(
      "synclog was truncated (exceeded capacity of %lu bytes)\n",
      (synclog_cap - 1) * sizeof(uint32_t)
    );
  }
  printf("synclog end\n");
  #endif
  #endif // defined(CUTLASS_ENABLE_SYNCLOG)
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Instrumentation shims: when synclog is enabled, calls to __syncthreads() /
// __syncwarp() are wrapped so each use logs its source line before issuing the
// real intrinsic. The intrinsic call inside the macro body is safe because the
// preprocessor does not re-expand a macro within its own expansion.
#if defined(CUTLASS_ENABLE_SYNCLOG)
#undef __syncthreads
#define __syncthreads() do {\
  cutlass::arch::synclog_emit_syncthreads(__LINE__);\
  __syncthreads();\
} while (0)
#endif // defined(CUTLASS_ENABLE_SYNCLOG)

#if defined(CUTLASS_ENABLE_SYNCLOG)
#undef __syncwarp
#define __syncwarp(...) do {\
  cutlass::arch::synclog_emit_syncwarp(__LINE__);\
  __syncwarp(__VA_ARGS__);\
} while (0)
#endif // defined(CUTLASS_ENABLE_SYNCLOG)


////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h
new file mode 100644
index 0000000000000000000000000000000000000000..2d4861ab682aca73d40ad5d0f298f9a265f7b9f2
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma.h
@@ -0,0 +1,218 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates exposing architecture support for warp matrix multiply-add (WMMA) operations +*/ + +#pragma once + +#if (__CUDACC_VER_MAJOR__ >= 9) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) +#define CUTLASS_ARCH_WMMA_ENABLED +#define CUTLASS_ARCH_WMMA_SM70_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 10) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) +#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED +#define CUTLASS_ARCH_WMMA_SM72_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 10) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED +#define CUTLASS_ARCH_WMMA_SM75_ENABLED +#endif +#endif + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +#include +#include "cutlass/arch/mma.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass data types => nvcuda::wmma data types +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassToWmmaDataType{ + using Type = Type_; +}; + +/// Statically maps cutlass::half_t => __half +template<> +struct CutlassToWmmaDataType { + using Type = __half; +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) +template<> +struct CutlassToWmmaDataType { + using Type = __nv_bfloat16; +}; +#endif + +/// Statically maps int8_t => char +template<> +struct CutlassToWmmaDataType { + using Type = signed char; +}; + +/// Statically maps uint8_t => char +template<> +struct CutlassToWmmaDataType { + using Type = unsigned char; +}; + +/// Statically maps int32_t => int +template<> +struct CutlassToWmmaDataType { + using Type = int; 
+}; + +#if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) +/// Statically maps cutlass::int4b_t => experimental::precision::s4 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::s4; +}; + +/// Statically maps cutlass::uint4b_t => experimental::precision::s4 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::u4; +}; + +/// Statically maps cutlass::uint1b_t => experimental::precision::b1 +template<> +struct CutlassToWmmaDataType { + using Type = nvcuda::wmma::experimental::precision::b1; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass::layout => nvcuda::wmma layout tags +//////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassToWmmaLayout { +}; + +/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags +template <> +struct CutlassToWmmaLayout { + using Layout = nvcuda::wmma::row_major; + static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass::layout::RowMajor => nvcuda::wmma::row_major layout tags +//////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct CutlassToWmmaLayout { + using Layout = nvcuda::wmma::col_major; + static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major; +}; +//////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps nvcuda::wmma data types => cutlass data types +///////////////////////////////////////////////////////////////////////////////////////////////// 
+template +struct WmmaToCutlassDataType{ + using Type = Type_; +}; + +/// Statically maps __half => cutlass::half_t +template<> +struct WmmaToCutlassDataType<__half> { + using Type = cutlass::half_t; +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) +template<> +struct WmmaToCutlassDataType<__nv_bfloat16> { + using Type = cutlass::bfloat16_t; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks +// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]), +// and native wmma size (Shape) +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, ///< Size of the matrix product (concept: GemmShape) + typename ElementA_, ///< Data type of A elements + typename LayoutA_, ///< Layout of A matrix (concept: MatrixLayout) + typename ElementB_, ///< Data type of B elements + typename LayoutB_, ///< Layout of B matrix (concept: MatrixLayout) + typename ElementC_, ///< Element type of C matrix + typename LayoutC_, /// Layout of C matrix (concept: MatrixLayout) + typename Operator_ = cutlass::arch::OpMultiplyAdd ///< Inner product operator (multiply-add, xor.popc) +> +struct Wmma; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Specializations for each compute capability +// +#ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED +#include "cutlass/arch/wmma_sm70.h" +#endif + +#ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED +#include "cutlass/arch/wmma_sm72.h" +#endif + +#ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED 
+#include "cutlass/arch/wmma_sm75.h" +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif //CUTLASS_ARCH_WMMA_ENABLED diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..2c540be88577b448a2abc75cf6478736a41eb716 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm70.h @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for half +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename ElementC_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::half_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::half_t, ///< ElementB + LayoutB_, ///< LayoutB + ElementC_, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { + +#if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::half_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::half_t; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = 
arch::Sm70; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + // check supported wmma output data type for the given multiplicand data types + static_assert( + platform::is_same::value || platform::is_same::value, + "Supported of wmma output data type for f16 multiplicands are: f16 and f32"); + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } +#else + static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h new file mode 100644 index 0000000000000000000000000000000000000000..1eb553e8f311e66e08e47dab15c6b08c29dec81c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm72.h @@ -0,0 +1,206 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for int8_t +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + int8_t, ///< ElementA + LayoutA_, ///< LayoutA + int8_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) + using Shape = Shape_; + using ElementA = int8_t; + using LayoutA = LayoutA_; + using ElementB = int8_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + 
+ using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } + +#else + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); +#endif + +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for uint8_t +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + uint8_t, ///< ElementA + LayoutA_, ///< LayoutA + uint8_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED) + using Shape = Shape_; + using ElementA = uint8_t; + using LayoutA = LayoutA_; + using ElementB = uint8_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm72; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value || + platform::is_same, Shape>::value || + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x32x16, and 32x8x16"); + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename 
CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + + nvcuda::wmma::mma_sync(D, A, B, C); + } + +#else + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..c3535ef0748e53b204b7d20cdd4aa82edc8c72a8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/arch/wmma_sm75.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/layout/matrix.h" + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4). 
+// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::int4b_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::int4b_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpMultiplyAdd ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::int4b_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::int4b_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpMultiplyAdd; + using ArchTag = arch::Sm75; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB const &B, + FragmentC const &C) const { + nvcuda::wmma::mma_sync(D, A, B, C); + + } + +#else + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); +#endif + +}; + 
+//////////////////////////////////////////////////////////////////////////////// +// +// WMMA template structure defines nvcuda::wmma::fragments and static assert for +// wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1). +// +//////////////////////////////////////////////////////////////////////////////// +template < +typename Shape_, +typename LayoutA_, +typename LayoutB_, +typename LayoutC_> +struct Wmma< + Shape_, ///< Size of the matrix product (concept: GemmShape) + cutlass::uint1b_t, ///< ElementA + LayoutA_, ///< LayoutA + cutlass::uint1b_t, ///< ElementB + LayoutB_, ///< LayoutB + int32_t, ///< ElementC + LayoutC_, ///< LayoutC + cutlass::arch::OpXorPopc ///< Operator (multiply-add, xor.popc) +> { +#if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) + using Shape = Shape_; + using ElementA = cutlass::uint1b_t; + using LayoutA = LayoutA_; + using ElementB = cutlass::uint1b_t; + using LayoutB = LayoutB_; + using ElementC = int32_t; + using LayoutC = LayoutC_; + using Operator = cutlass::arch::OpXorPopc; + using ArchTag = arch::Sm75; + + // check supported wmma shape for the given multiplicand data types + static_assert( + platform::is_same, Shape>::value, + "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128"); + + + // Wmma Fragment + using FragmentA = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_a, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentB = nvcuda::wmma::fragment< + nvcuda::wmma::matrix_b, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type, + typename CutlassToWmmaLayout::Layout>; + + using FragmentC = nvcuda::wmma::fragment< + nvcuda::wmma::accumulator, + Shape::kM, + Shape::kN, + Shape::kK, + typename CutlassToWmmaDataType::Type>; + + /// Performs a nvcuda::wmma matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + FragmentA const &A, + FragmentB 
const &B, + FragmentC const &C) const { + nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, + nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + } + +#else + static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond"); +#endif + +}; + +} // namespace arch +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h new file mode 100644 index 0000000000000000000000000000000000000000..ce33110aa4f44e7deba56a5f9fe4db206a6889ce --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array.h @@ -0,0 +1,2860 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. 
+*/ + +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template < + typename T, + int N, + bool RegisterSized = sizeof_bits::value >= 32 +> +struct Array; + +namespace detail { + +template +struct is_Array : platform::false_type {}; + +template < + typename T, + int N, + bool RegisterSized +> +struct is_Array > : platform::true_type {}; + +template +constexpr bool is_Array_v = is_Array::value; + +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines the size of an Array<> in bits +template +struct sizeof_bits > { + static constexpr int value = sizeof(Array) * 8; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if the argument is a power of 2 +CUTLASS_HOST_DEVICE +constexpr bool ispow2(unsigned x) { + return x && (!(x & (x - 1))); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the largest power of two not greater than the argument. +CUTLASS_HOST_DEVICE +constexpr unsigned floor_pow_2(unsigned x) { + return (x == 0 || ispow2(x)) ? 
x : ((floor_pow_2(x >> 1)) << 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template < + typename T, + int N +> +struct Array { + + /// Storage type + using Storage = T; + + /// Element type + using Element = T; + + /// Number of storage elements + //static std::size_t const kStorageElements = N; + static constexpr size_t kStorageElements = N; + + /// Number of logical elements + static constexpr size_t kElements = N; + + // + // C++ standard members + // + + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type &reference; + typedef value_type const & const_reference; + typedef value_type *pointer; + typedef value_type const * const_pointer; + + // + // Iterators + // + + /// Bidirectional iterator over elements + class iterator { + + /// Pointer to object + T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + iterator(T *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T &operator*() const { + return *ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional constant iterator over elements + class const_iterator { + + /// Pointer to object + const T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + const_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + const_iterator(T const *_ptr): ptr_(_ptr) { } 
+ + CUTLASS_HOST_DEVICE + const_iterator &operator++() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_iterator &operator--() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_iterator operator++(int) { + const_iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + const_iterator operator--(int) { + const_iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T const &operator*() const { + return *ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator==(const_iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(const_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional iterator over elements + class reverse_iterator { + + /// Pointer to object + T *ptr_; + + public: + + CUTLASS_HOST_DEVICE + reverse_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + reverse_iterator(T *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + reverse_iterator &operator++() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + reverse_iterator &operator--() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + reverse_iterator operator++(int) { + iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + reverse_iterator operator--(int) { + iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T &operator*() const { + return *(ptr_ - 1); + } + + CUTLASS_HOST_DEVICE + bool operator==(reverse_iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(reverse_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Bidirectional constant iterator over elements + class const_reverse_iterator { + + /// Pointer to object + T const *ptr_; + + public: + + CUTLASS_HOST_DEVICE + const_reverse_iterator(): ptr_(nullptr) { } + + CUTLASS_HOST_DEVICE + const_reverse_iterator(T const *_ptr): ptr_(_ptr) { } + + CUTLASS_HOST_DEVICE + 
const_reverse_iterator &operator++() { + --ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator &operator--() { + ++ptr_; + return *this; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator operator++(int) { + const_reverse_iterator ret(*this); + --ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator operator--(int) { + const_reverse_iterator ret(*this); + ++ptr_; + return ret; + } + + CUTLASS_HOST_DEVICE + T const &operator*() const { + return *(ptr_ - 1); + } + + CUTLASS_HOST_DEVICE + bool operator==(const_iterator const &other) const { + return ptr_ == other.ptr_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(const_iterator const &other) const { + return ptr_ != other.ptr_; + } + }; + + /// Internal storage + Storage storage[kElements]; + + /// Efficient clear method + CUTLASS_HOST_DEVICE + void clear() { + fill(T(0)); + } + + CUTLASS_HOST_DEVICE + reference at(size_type pos) { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + const_reference at(size_type pos) const { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + reference operator[](size_type pos) { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + const_reference operator[](size_type pos) const { + return reinterpret_cast(storage[pos]); + } + + CUTLASS_HOST_DEVICE + reference front() { + return reinterpret_cast(storage[0]); + } + + CUTLASS_HOST_DEVICE + const_reference front() const { + return reinterpret_cast(storage[0]); + } + + CUTLASS_HOST_DEVICE + reference back() { + return reinterpret_cast(storage[kStorageElements - 1]); + } + + CUTLASS_HOST_DEVICE + const_reference back() const { + return reinterpret_cast(storage[kStorageElements - 1]); + } + + CUTLASS_HOST_DEVICE + pointer data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + pointer raw_data() { + return 
reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer raw_data() const { + return reinterpret_cast(storage); + } + + + CUTLASS_HOST_DEVICE + constexpr bool empty() const { + return !kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type max_size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + void fill(T const &value) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(kElements); ++i) { + storage[i] = static_cast(value); + } + } + + CUTLASS_HOST_DEVICE + iterator begin() { + return iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_iterator begin() const { + return cbegin(); + } + + CUTLASS_HOST_DEVICE + const_iterator cbegin() const { + return const_iterator(storage); + } + + CUTLASS_HOST_DEVICE + iterator end() { + return iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + const_iterator end() const { + return cend(); + } + + CUTLASS_HOST_DEVICE + const_iterator cend() const { + return const_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rbegin() { + return reverse_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator rbegin() const { + return crbegin(); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crbegin() const { + return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rend() { + return reverse_iterator(reinterpret_cast(storage)); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator rend() const { + return crend(); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crend() const { + return const_reverse_iterator(reinterpret_cast(storage)); + } + + // + // Comparison operators + // + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Factories 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x) { + return {x}; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y) { + return {x,y}; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z) { + return {x,y,z}; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z, Element w) { + return {x,y,z,w}; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct absolute_value_op< Array > { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + absolute_value_op scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; +template +struct minus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int 
i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct multiplies> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum_absolute_value_reduction, PropogateNaN> { + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, Array const& rhs) const { + + T result = scalar; + maximum_absolute_value_reduction scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; + +template +struct scale> { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(T scaling_factor) : scaling_factor_(scaling_factor) { + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & rhs) const { + Array result; + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = rhs[i] * scaling_factor_; + } + + return result; + } +}; + +template +struct divides> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct reciprocal_approximate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct reciprocal_approximate_ftz> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate_ftz scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct maximum, PropagateNaN> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = 
scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum, PropagateNaN> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum_with_nan_propagation> : minimum, true> +{}; + +template +struct negate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + negate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const 
{ + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, T const &scalar) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], scalar); + } + + return result; + } + + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar_b, T const &scalar_c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar_b, scalar_c); + } + + return result; + } +}; + +/// Fused square-and-plus +template +struct square_and_plus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + multiply_add, Array, Array> ma_op; + return ma_op(rhs, rhs, lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &rhs) const { + plus> plus_op; + multiplies multiplies_op; + return plus_op(multiplies_op(rhs, rhs), lhs); + } +}; + +/// Inverse-square-root +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + Array result; + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + return result; + } +}; + +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & a) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = 
reinterpret_cast<__half2 const *>(&a); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = h2rsqrt(a_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half d_residual = hrsqrt(a_residual_ptr[N - 1]); + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); + } + + return result; + } +}; + + +template +struct conjugate > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + + conjugate conj_op; + + Array ca; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + ca[i] = conj_op(a[i]); + } + return ca; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations targeting SIMD instructions in device code. 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array 
result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs; + } + #endif + + return result; + } +}; + +template +struct minus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const 
&>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs; + } + #endif + + return result; + } +}; + +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr 
= reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmul( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmul( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs; + } + #endif + + return result; + } +}; + +template +struct divides> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, 
half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs; + } + #endif + + return result; + } +}; + +template +struct negate> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hneg2(source_ptr[i]); + } + + if constexpr (N % 2) { + half_t x = -lhs[N - 1]; + __half lhs_val = reinterpret_cast<__half const &>(x); + result[N - 1] = reinterpret_cast(lhs_val); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = -lhs[i]; + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = 
reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = 
reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + 
half_t const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c); + } + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = 
reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma_relu( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a, b[i], c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + 
__half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b, c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c), half_t(0)); + } + #endif + + return result; + } +}; + +template +struct minimum, PropagateNaN> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N 
/ 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmin2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs[i],rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i]) + : __hmin2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? 
__hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs, rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair) + : __hmin2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs[i], rhs); + } + #endif + + return result; + } +}; + +template +struct maximum, PropagateNaN> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? 
__hmax2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(lhs[i], rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i]) + : __hmax2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? 
__hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(lhs, rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair) + : __hmax2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = PropagateNaN ? 
__hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(lhs[i], rhs); + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + bfloat16_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned a_packed = static_cast(a.raw()); + a_packed = (a_packed | (a_packed << 16)); + + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + 
Array const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_packed) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const 
*c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[0]) + ); + } + + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c); + } + #endif + + return result; + } +}; + + +/// bit_and +template +struct bit_and> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] & b_data[i]); + } + + return result; + } +}; + + +/// bit_or +template +struct bit_or> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] | b_data[i]); + } + + return result; + } +}; + + +/// bit_not +template +struct bit_not> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (~a_data[i]); + } + + return result; + } +}; + +/// bit_xor +template +struct bit_xor> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename 
ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] ^ b_data[i]); + } + + return result; + } +}; + +/// Fused and-popc-add +template +struct and_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + and_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + + +/// Fused or-popc-add +template +struct or_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + or_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for 
(int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + +/// Fused xor-popc-add +template +struct xor_popc_add, Array, Array> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + Array result; + xor_popc_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Operator overloads +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator+(T const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, T const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs, Array const &rhs) { + minus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator-(Array const &lhs) { + negate> op; + return op(lhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array 
operator*(T lhs, Array const &rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator*(Array const &lhs, T rhs) { + multiplies> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator/(Array const &lhs, Array const &rhs) { + divides> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(T a, Array const &b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, T b, Array const &c) { + multiply_add> op; + return op(a, b, c); +} + +template +CUTLASS_HOST_DEVICE +Array fma(Array const &a, Array const &b, T c) { + multiply_add> op; + return op(a, b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// AlignedArray +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Aligned array type +template < + /// Element type + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = ( sizeof_bits::value * N + 7 ) / 8 +> +class alignas(Alignment) AlignedArray: public Array { +public: + +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/array_subbyte.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h new file mode 100644 index 
0000000000000000000000000000000000000000..0bd9d0d7f7dc709b951c6979a3e26cf05ba9c79d --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_planar_complex.h @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Array holding planar complex elements +template +struct ArrayPlanarComplex { + + /// Underlying real element + using Element = Element_; + + /// Number of logical elements + static constexpr size_t kElements = N; + + /// Underlying Fragment of real-valued elemenets + using ArrayReal = cutlass::Array; + +public: + /// Fragment of real-valued elements representing the real part + ArrayReal real; + + /// Fragment of real-valued elements representing the imaginary part + ArrayReal imag; + +public: + /// Sets the array to zero efficiently + CUTLASS_HOST_DEVICE + void clear() { + real.clear(); + imag.clear(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to deduce template arguments +template +CUTLASS_HOST_DEVICE +ArrayPlanarComplex +make_ArrayPlanarComplex(Array const &real, Array const &imag) { + return ArrayPlanarComplex{real, imag}; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h new file mode 100644 index 0000000000000000000000000000000000000000..756890bb61f7ff5f2a9912b00b98a54deae6ee75 --- 
/dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/array_subbyte.h @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically sized array for any data type +template < + typename T, + int N +> +struct Array { + static constexpr int kSizeBits = sizeof_bits::value * N; + + /// Storage type + using Storage = typename platform::conditional< + ((kSizeBits % 32) != 0), + typename platform::conditional< + ((kSizeBits % 16) != 0), + uint8_t, + uint16_t + >::type, + uint32_t + >::type; + + /// Element type + using Element = T; + + /// Number of logical elements per stored object + static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; + + /// Number of storage elements + static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + + /// Number of logical elements + static constexpr size_t kElements = N; + + /// Bitmask for covering one item + static constexpr Storage kMask = ((Storage(1) << sizeof_bits::value) - 1); + + // + // C++ standard members with pointer types removed + // + + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + typedef value_type *pointer; + typedef value_type const *const_pointer; + + // + // References + // + + /// Reference object inserts or extracts sub-byte items + class reference { + /// Pointer to storage element + Storage *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + reference() = default; + + /// Ctor + CUTLASS_HOST_DEVICE + reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + /// Assignment + CUTLASS_HOST_DEVICE + reference &operator=(T x) { + // `*ptr_ & kUpdateMask` will read ptr_ 
before write to it + // This means code pattern like + // + // ```cpp + // Array result; + // result[0] = xxx; + // ``` + // + // Will leads to compiler warning on use of uninitialized member variable. Although we know + // this read of uninitialized member variable is harmeless. + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wuninitialized" +#elif defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuninitialized" +# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + + Storage item = (reinterpret_cast(x) & kMask); + + Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); + + *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); + +#if defined(__clang__) +# pragma clang diagnostic pop +#elif defined(__GNUC__) +# pragma GCC diagnostic pop +#endif + + return *this; + } + + CUTLASS_HOST_DEVICE + T get() const { + Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Extract + CUTLASS_HOST_DEVICE + operator T() const { + return get(); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + }; + + /// Reference object extracts sub-byte items + class const_reference { + + /// Pointer to storage element + Storage const *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + const_reference() = default; + + /// Ctor + CUTLASS_HOST_DEVICE + const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + const T get() const { + Storage item = (*ptr_ >> (idx_ * sizeof_bits::value)) & kMask; + return reinterpret_cast(item); + } + + /// Extract + CUTLASS_HOST_DEVICE + operator T() const { + Storage item = Storage(Storage(*ptr_ 
>> Storage(idx_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + }; + + // + // Iterators + // + + /// Bidirectional iterator over elements + class iterator { + + /// Pointer to storage element + Storage *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + iterator() = default; + + CUTLASS_HOST_DEVICE + iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return ret; + } + + CUTLASS_HOST_DEVICE + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return !(*this == other); + } + }; + + /// Bidirectional constant iterator over elements + class const_iterator { + + /// Pointer to storage element + Storage const *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + const_iterator() = default; + + CUTLASS_HOST_DEVICE + const_iterator(Storage const *ptr, int idx = 0): 
ptr_(ptr), idx_(idx) { } + + CUTLASS_HOST_DEVICE + iterator &operator++() { + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator &operator--() { + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return *this; + } + + CUTLASS_HOST_DEVICE + iterator operator++(int) { + iterator ret(*this); + ++idx_; + if (idx_ == kElementsPerStoredItem) { + ++ptr_; + idx_ = 0; + } + return ret; + } + + CUTLASS_HOST_DEVICE + iterator operator--(int) { + iterator ret(*this); + if (!idx_) { + --ptr_; + idx_ = kElementsPerStoredItem - 1; + } + else { + --idx_; + } + return ret; + } + + CUTLASS_HOST_DEVICE + const_reference operator*() const { + return const_reference(ptr_, idx_); + } + + CUTLASS_HOST_DEVICE + bool operator==(iterator const &other) const { + return ptr_ == other.ptr_ && idx_ == other.idx_; + } + + CUTLASS_HOST_DEVICE + bool operator!=(iterator const &other) const { + return !(*this == other); + } + }; + + /// Bidirectional iterator over elements + class reverse_iterator { + + /// Pointer to storage element + Storage *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + reverse_iterator() = default; + + CUTLASS_HOST_DEVICE + reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + }; + + /// Bidirectional constant iterator over elements + class const_reverse_iterator { + + /// Pointer to storage element + Storage const *ptr_{nullptr}; + + /// Index into elements packed into Storage object + int idx_{0}; + + public: + + const_reverse_iterator() = default; + + CUTLASS_HOST_DEVICE + const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + }; + + /// Efficient clear method + CUTLASS_HOST_DEVICE + void clear() { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(kStorageElements); ++i) { + storage[i] = Storage(0); + } + } + + CUTLASS_HOST_DEVICE + reference at(size_type 
pos) { + return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTLASS_HOST_DEVICE + const_reference at(size_type pos) const { + return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + } + + CUTLASS_HOST_DEVICE + reference operator[](size_type pos) { + return at(pos); + } + + CUTLASS_HOST_DEVICE + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTLASS_HOST_DEVICE + reference front() { + return at(0); + } + + CUTLASS_HOST_DEVICE + const_reference front() const { + return at(0); + } + + CUTLASS_HOST_DEVICE + reference back() { + return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTLASS_HOST_DEVICE + const_reference back() const { + return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + } + + CUTLASS_HOST_DEVICE + pointer data() { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + const_pointer data() const { + return reinterpret_cast(storage); + } + + CUTLASS_HOST_DEVICE + Storage * raw_data() { + return storage; + } + + CUTLASS_HOST_DEVICE + Storage const * raw_data() const { + return storage; + } + + CUTLASS_HOST_DEVICE + constexpr bool empty() const { + return !kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + constexpr size_type max_size() const { + return kElements; + } + + CUTLASS_HOST_DEVICE + void fill(T const &value) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerStoredItem; ++i) { + reference ref(storage, i); + ref = value; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kStorageElements; ++i) { + storage[i] = storage[0]; + } + } + + CUTLASS_HOST_DEVICE + iterator begin() { + return iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_iterator cbegin() const { + return const_iterator(storage); + } + + CUTLASS_HOST_DEVICE + iterator end() { + return iterator(storage + 
kStorageElements); + } + + CUTLASS_HOST_DEVICE + const_iterator cend() const { + return const_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rbegin() { + return reverse_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crbegin() const { + return const_reverse_iterator(storage + kStorageElements); + } + + CUTLASS_HOST_DEVICE + reverse_iterator rend() { + return reverse_iterator(storage); + } + + CUTLASS_HOST_DEVICE + const_reverse_iterator crend() const { + return const_reverse_iterator(storage); + } + +private: + /// Internal storage + Storage storage[kStorageElements]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h new file mode 100644 index 0000000000000000000000000000000000000000..8919e992af20ac2d7f2b5daa8a0cbd7a6f7b79e5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/barrier.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// +// Utilities for abstracting synchronization methods for barriers +// + +struct SyncthreadsSync { + CUTLASS_DEVICE + static void sync() { + __syncthreads(); + } +}; + +struct SyncwarpSync { + CUTLASS_DEVICE + static void sync() { + __syncwarp(); + } +}; + +template < + int ThreadCount, + int BarrierId +> +struct NamedBarrierSync { + CUTLASS_DEVICE + static void sync() { + cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast(BarrierId)); + } +}; + +} // namepspace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Group or CTA-wide semaphore for inter-CTA synchronization. +template +struct GenericBarrier { + +public: + + /// Flag type + using T = int; + + /// Initial flag value + static const T INIT = 0; + + +protected: + + /// Load flag, as a strong acquire operation (int specialization) + CUTLASS_DEVICE + static int ld_acquire(int *ptr) + { + int state = 0; + +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Acquire pattern using acquire modifier + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + +#else + asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif // (__CUDA_ARCH__ >= 700) + + return state; + } + + + /// Reduce into flag, with release pattern (int specialization) + CUTLASS_DEVICE + static void red_release(int *ptr, int val) + { +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Release pattern using acq_rel fence + relaxed modifier. 
(The fence also releases data + // that was weakly-written by other threads prior to the last syncthreads) + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); + +#else + __threadfence(); + atomicAdd(ptr, val); +#endif // (__CUDA_ARCH__ >= 700) + } + + +public: + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) < count) {} + } + + Sync::sync(); + } + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) != val) {} + } + Sync::sync(); + } + + /// Uses thread[0] to wait for the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(atomicCAS(flag_ptr, val, 0) != val) {} + } + + Sync::sync(); + } + + /// Increment the arrival count for a flag + CUTLASS_DEVICE + static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1) + { + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + Sync::sync(); + + if (thread_idx == 0) + { + red_release(flag_ptr, val); + } + } + + + /// Increment the arrival counts for a range of flags + CUTLASS_DEVICE + static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) + { + int flag_idx = 
first_flag_idx + thread_idx; + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + // Barrier to make sure all other threads in group have written their data + Sync::sync(); + + // Select threads increment their flags + if (thread_idx < count) { + red_release(flag_ptr, val); + } + } +}; + +using Barrier = GenericBarrier; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing + * runtime index values to be used to call into named barriers with compile-time-constant IDs. + * + * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID + * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into + * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type +**/ +template < + uint32_t ThreadCount_, + uint32_t Offset = 0, + uint32_t MaxNumNamedBarriers = 16 +> +struct NamedBarrierManager { + + static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers); + static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + template + using BarrierSync = cutlass::GenericBarrier>; + + // Underlying type used by all barriers for synchronization. Does not depend on + // template parameter BarrierId, so passing in 0 suffices. 
+ using T = typename BarrierSync<0>::T; + + using IntegerSequence = cute::make_integer_sequence; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) { + wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{}); + } + +private: + CUTLASS_DEVICE + static void + check_barrier_in_range([[maybe_unused]] uint32_t idx) { + assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count"); + } + + template + CUTLASS_DEVICE + static void + wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence) { + check_barrier_in_range(idx); + if constexpr (Reset) { + ((Idx == idx && (BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + else { + ((Idx == idx && 
(BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + } + + template + CUTLASS_DEVICE + static void + arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads) + * via an API that mirrors that of NamedBarrierManager + * + * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization +**/ +template < + class Synchronizer, + uint32_t ThreadCount_ +> +struct SyncManager { + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + using BarrierSync = cutlass::GenericBarrier; + + // Underlying type used by all barriers for synchronization. 
+ using T = typename BarrierSync::T; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { + BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h new file mode 100644 index 0000000000000000000000000000000000000000..5e2f40b1c85e24eb2bdeedb191529d53539f050c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/bfloat16.h @@ -0,0 +1,679 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing non-standard 16-bit floating point values with + 8 bits of exponent and 7 bit of mantissa. 
+*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else +#include +#include +#include +#include +#endif + +#include +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Floating-point type with 8 bits of exponent and 7 bits of mantissa. +struct alignas(2) bfloat16_t { + + // + // Data members + // + + /// Storage type + uint16_t storage; + + // + // Methods + // + + /// Constructs from an unsigned short + CUTLASS_HOST_DEVICE + static bfloat16_t bitcast(uint16_t x) { + bfloat16_t h; + h.storage = x; + return h; + } + +private: + struct from_32_bit_integer_t {}; + static constexpr from_32_bit_integer_t from_32_bit_integer{}; + + template + CUTLASS_HOST_DEVICE + explicit bfloat16_t(from_32_bit_integer_t, T x) { + static_assert(cutlass::platform::is_integral::value && sizeof(T) == 4, "Requires 32-bit integer"); + + float flt = static_cast(x); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(flt); + #else + std::memcpy(&bits, &flt, sizeof(bits)); + #endif + + storage = uint16_t(bits >> 16); + } + +public: + /// Default constructor + bfloat16_t() = default; + + /// Reinterpret cast from CUDA's __nv_bfloat16 type + CUTLASS_HOST_DEVICE + explicit bfloat16_t(__nv_bfloat16 const & x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + __nv_bfloat16_raw raw(x); + std::memcpy(&storage, &raw.x, sizeof(storage)); + #endif + } + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(float x) { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11) + + asm("cvt.rn.bf16.f32 %0, %1;\n" : "=h"(storage) : "f"(x)); + + #else + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(x); + #else + std::memcpy(&bits, &x, sizeof(bits)); + #endif + + if ((bits 
& 0x7f800000) != 0x7f800000) { + + bool mantissa_bit = ((bits & (1 << 16)) != 0); + bool round_bit = ((bits & (1 << 15)) != 0); + bool sticky_bit = ((bits & ((1 << 15) - 1)) != 0); + + if ((round_bit && sticky_bit) || (round_bit && mantissa_bit)) { + bits += uint32_t(1 << 16); + } + } + else if (bits & ~0xff800000) { + bits = 0x7fffffff; + } + + storage = uint16_t((bits >> 16) & 0xffff); + #endif + } + + /// Floating-point conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(double x): bfloat16_t(float(x)) { + + } + + /// Integer conversion - round toward nearest + CUTLASS_HOST_DEVICE + explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {} + + CUTLASS_HOST_DEVICE + explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {} + + /// Converts to float + CUTLASS_HOST_DEVICE + operator float() const { + unsigned bits = (unsigned(storage) << 16); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif + } + + /// Converts to float + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + return (float(*this) != 0.0f); + } + + /// Bitcasts to CUDA's bf16 type + CUTLASS_DEVICE + __nv_bfloat16 to_nv_bfloat16() const { + return reinterpret_cast<__nv_bfloat16 const &>(storage); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint16_t raw() const { + return storage; + } + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x8000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 7) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return 
exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7f); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::bfloat16_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { + return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t nan_bf16(const char*) { + // NVIDIA canonical NaN + return cutlass::bfloat16_t::bitcast(0x7fff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::bfloat16_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isnormal(cutlass::bfloat16_t const& h) { + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::bfloat16_t const& h) { + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t sqrt(cutlass::bfloat16_t const& h) { +#if defined(__CUDACC_RTC__) + return cutlass::bfloat16_t(sqrtf(float(h))); +#else + return cutlass::bfloat16_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { + + uint16_t a_bits; + uint16_t b_bits; + + #if defined(__CUDA_ARCH__) + a_bits = reinterpret_cast(a); + b_bits = reinterpret_cast(b); + #else + std::memcpy(&a_bits, &a, sizeof(a_bits)); + 
std::memcpy(&b_bits, &b, sizeof(b_bits)); + #endif + + uint16_t a_mag = (a_bits & 0x7fff); + uint16_t b_sign = (b_bits & 0x8000); + uint16_t result = (a_mag | b_sign); + + return bfloat16_t::bitcast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +namespace std { + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { 
return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; + +} // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + 
CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; + +} // namespace platform +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return float(lhs) == float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return float(lhs) != float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return 
float(lhs) < float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return float(lhs) <= float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return float(lhs) > float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else + return float(lhs) >= float(rhs); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + return bfloat16_t(float(lhs) + float(rhs)); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hneg(lhs.to_nv_bfloat16())); +#else + return bfloat16_t(-float(lhs)); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + return bfloat16_t(float(lhs) - float(rhs)); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + return bfloat16_t(float(lhs) * float(rhs)); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) 
&& (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + return bfloat16_t(float(lhs) / float(rhs)); +#endif +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + lhs = bfloat16_t(float(lhs) + float(rhs)); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + lhs = bfloat16_t(float(lhs) - float(rhs)); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + lhs = bfloat16_t(float(lhs) * float(rhs)); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else + lhs = bfloat16_t(float(lhs) / float(rhs)); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator++(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else + float tmp(lhs); + ++tmp; + lhs = bfloat16_t(tmp); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t& operator--(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else + float tmp(lhs); + --tmp; + lhs = bfloat16_t(tmp); +#endif + return lhs; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator++(bfloat16_t & lhs, int) { + bfloat16_t 
ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else + float tmp(lhs); + tmp++; + lhs = bfloat16_t(tmp); +#endif + return ret; +} + +CUTLASS_HOST_DEVICE +bfloat16_t operator--(bfloat16_t & lhs, int) { + bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else + float tmp(lhs); + tmp--; + lhs = bfloat16_t(tmp); +#endif + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals +// + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(long double x) { + return cutlass::bfloat16_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::bfloat16_t operator "" _bf16(unsigned long long int x) { + return cutlass::bfloat16_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h new file mode 100644 index 0000000000000000000000000000000000000000..8788f18b99d5c9d700a0f6f28625097f41862c74 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic include for CUTLASS BLAS3/HPC code. 
+ + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/blas3_types.h" +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines FillMode inversions +template +struct InvertFillMode; + +/// Invert FillMode lower to upper +template <> +struct InvertFillMode { + static FillMode const mode = FillMode::kUpper; +}; + +/// Invert FillMode upper to lower +template <> +struct InvertFillMode { + static FillMode const mode = FillMode::kLower; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines SideMode inversions +template +struct InvertSideMode; + +/// Invert SideMode left to right +template <> +struct InvertSideMode { + static SideMode const mode = SideMode::kRight; +}; + +/// Invert SideMode right to left +template <> +struct InvertSideMode { + static SideMode const mode = SideMode::kLeft; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines correct compare operation for Triangular matrix boundary +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater_equal, + less_equal>::type; +}; + +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater_equal, + less_equal>::type; +}; + +template +struct TrMatrixCompareOp { + using Index = int32_t; + using Type = typename platform::conditional< + (kFillMode == FillMode::kLower), + greater, + less>::type; +}; 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +// Returns precision in terms of bits (based on datatype) to fill tensors with. +// Defaults to 5 bits of mantissa for TF32 and FP32 (with implicit round-offs). +// Also defines acceptable mantissa result variance/error. +template +struct MantissaInBits { + static int constexpr bits = 5; + static double constexpr error = 1.0e-7; +}; + +// Full precision is supported for FP64 +template <> +struct MantissaInBits { + static int constexpr bits = 30; + static double constexpr error = 1.0e-15; +}; + +template <> +struct MantissaInBits> { + static int constexpr bits = 30; + static double constexpr error = 1.0e-14; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h new file mode 100644 index 0000000000000000000000000000000000000000..e47002b1a7255478f3a8d08518a4e081cbfd2422 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/blas3_types.h @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Enumerated type describing the type of kernel (based on input or output matrices).
enum class BlasMode {
  kGemm,
  kSymmetric,
  kHermitian,
  kTriangular,
  kInvalid
};

/// Enumerated type describing the fill mode for matrices for BLAS functions.
enum class FillMode {
  kFull,       /// The entire tensor is covered.
  kLower,      /// The 'lower' part of a tensor is covered including diagonal
  kUpper,      /// The 'upper' part of a tensor is covered including diagonal
  kDiagonal,   /// Only diagonal elements are covered.
  kNone,       /// No element is covered.
  kInvalid
};

/// Enumerated type describing the diagonal property of matrices for BLAS functions.
enum class DiagType {
  kNonUnit,
  kUnit,
  kZero,     // Only used internally for computing SYMM/HEMM
  kInvalid
};

/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions.
enum class SideMode {
  kLeft,
  kRight,
  kInvalid
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h
new file mode 100644
index 0000000000000000000000000000000000000000..93665c64047d847a6fc9de3f5ec691caa8186dbc
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/block_striped.h
@@ -0,0 +1,267 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, + statically-sized array types to global memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/wmma_array.h" +#include "cutlass/functional.h" +#include "cutlass/complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// AccessWidth +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit +template < + typename T, + int Limit> +struct AccessWidth +{ + // Inductive case + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes, /// Template induction variable + bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes + ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> + struct Detail + { + static const int value = Detail::value; + }; + + // Base case (ObjectBytes is not an even multiple of AlignBytes) + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes> /// Template induction variable + struct Detail + { + static const int value = AlignBytes / 2; + }; + + /// The maximal power-of-two that evenly divides the size of T + static const int value = Detail< + (int) sizeof(T), + 1>::value; +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// StripedAccessType +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Default specialization. Striping granularity is type T.) 
+template < + typename T, /// Data type + int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) + AccessWidth::value> +struct alignas(TransferBytes) StripedAccessType : public T +{}; + + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) +template < + typename T, /// Array element type + int N, /// Number of elements in array + bool RegisterSized, /// T is register-sized + int TransferBytes> /// Data access width +struct StripedAccessType< + Array, + TransferBytes> +: public AlignedArray< + T, // Element type of StripedAccessType + __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType + TransferBytes> // Alignment of StripedAccessType +{}; + + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) 
+template< + typename Use, + int m, + int n, + int k, + typename ElementT, + typename Layout, + int kFragments, + int TransferBytes> +struct StripedAccessType< + WmmaFragmentArray, kFragments>, + TransferBytes> +: public AlignedArray< + ElementT, + __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), + TransferBytes> +{}; + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStriped +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Utility for performing block-striped access (load, store) of trivially-copyable, +/// statically-sized array types to global memory +template < + int BlockThreads, + typename ArrayT, + typename AccessT = StripedAccessType > +struct BlockStriped +{ + /// Number of striped accesses + static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); + static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); + + /// Load + CUTLASS_DEVICE + static void load(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_data[i] = access_input[(BlockThreads * i) + thread_idx]; + } + } + + /// Load & Add + CUTLASS_DEVICE + static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + plus add; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) + { + access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); + } + } + + /// Store + CUTLASS_DEVICE + static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + AccessT *access_output = reinterpret_cast(ptr); + const AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + 
for (int i = 0; i < kStripes; ++i) { + access_output[(BlockThreads * i) + thread_idx] = access_data[i]; + } + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStripedReduce +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Default specialization) +template < + int BlockThreads, + typename ArrayT, + typename ElementT = typename StripedAccessType::Element> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + ElementT> +{ + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + ElementT *access_output = reinterpret_cast(ptr); + const ElementT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Specialization for half_t. Uses half2 vectorized-reduction.) 
+template < + int BlockThreads, + typename ArrayT> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + half2> +{ + static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); + + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + half2 *access_output = reinterpret_cast(ptr); + const half2 *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) + { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..22c17dba702f62eeab80ab5b3399bda269f4f4d2 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cluster_launch.hpp @@ -0,0 +1,394 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUDA interfaces to launch CUTLASS device-level operators (for >= SM90) that use thread-block clusters. 
+*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" +#include +#include "cutlass/arch/synclog.hpp" + +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(type_traits) +#else +#include +#include +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) + # define CUDA_ENABLE_PREFERRED_CLUSTER +#endif +namespace cutlass { + +#ifndef NDEBUG +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + fprintf(stderr, \ + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(cudaError_t_status)); \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#else +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#endif + +struct ClusterLauncher { + constexpr static int MaxClusterSize = 32; + + struct LaunchConfig { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + cudaLaunchConfig_t launch_config; + + #if defined(CUDA_ENABLE_PREFERRED_CLUSTER) + constexpr static int numAttrs = 3; + #else + + constexpr static int numAttrs = 2; + #endif + cudaLaunchAttribute launch_attribute[numAttrs]; + // Commonly used utility functions + dim3 gridDim() { return launch_config.gridDim; } + dim3 blockDim() { return launch_config.blockDim; } +#endif + }; + + // Check for hardware compatibility + static inline CUTLASS_HOST + Status check_cluster_dims(dim3 grid, dim3 cluster) { + if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && + (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting 
launch."); + return Status::kInvalid; + } + } + + static inline CUTLASS_HOST + Status +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + init(void const* kernel_function) +#else + init(void const* /* kernel_function */) +#endif + { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (kernel_function == nullptr) { + CUTLASS_TRACE_HOST("kernel_function is null"); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute"); + cudaError_t prevStatus = cudaGetLastError(); + if (prevStatus != cudaSuccess) { + fprintf(stderr, + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", + __FILE__, + __LINE__, + cudaGetErrorString(prevStatus)); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute"); +#endif + // This attribute was added in CUDA 11.8. + cudaError_t status = + cudaFuncSetAttribute( + kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); + Return_Status(status); +#else + return Status::kInvalid; +#endif + } + + static inline CUTLASS_HOST + LaunchConfig make_cluster_launch_config( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size = 0, + cudaStream_t cuda_stream = 0, + bool launch_with_pdl = false + , dim3 const fallback_cluster_dims = {0, 0, 0} + ) { + LaunchConfig cluster_launch_config; +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + auto &launch_config = cluster_launch_config.launch_config; + auto &launch_attribute = cluster_launch_config.launch_attribute; + auto numAttrs = cluster_launch_config.numAttrs; + + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + + bool have_fallback = fallback_cluster_dims.x * fallback_cluster_dims.y * fallback_cluster_dims.z > 0; + + if (have_fallback) { + launch_attribute[0].val.clusterDim = {fallback_cluster_dims.x, fallback_cluster_dims.y, fallback_cluster_dims.z}; + 
CUTLASS_TRACE_HOST("ClusterLauncher: Setting fallback ClusterDims = " + "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); + } + else { + + launch_attribute[0].val.clusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; + CUTLASS_TRACE_HOST("ClusterLauncher: Setting ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + } + +#if defined(CUDA_ENABLE_PREFERRED_CLUSTER) + if (have_fallback) { + if (cute::initialize_preferred_cluster_launch(nullptr, grid_dims, cluster_dims, fallback_cluster_dims)) { + launch_attribute[1].id = cudaLaunchAttributePreferredClusterDimension; + launch_attribute[1].val.preferredClusterDim = {cluster_dims.x, cluster_dims.y, cluster_dims.z}; + CUTLASS_TRACE_HOST("ClusterLauncher: Setting preferred ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + } + } + else { + numAttrs--; + } +#endif + + + // PDL attributes + launch_attribute[numAttrs - 1].id = cudaLaunchAttributeProgrammaticStreamSerialization; + launch_attribute[numAttrs - 1].val.programmaticStreamSerializationAllowed = 1; + + launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; + launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; + launch_config.dynamicSmemBytes = smem_size; + launch_config.stream = cuda_stream; + launch_config.numAttrs = launch_with_pdl ? numAttrs : numAttrs - 1; + launch_config.attrs = launch_attribute; + return cluster_launch_config; +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! 
Aborting cluster launch."); + return cluster_launch_config; +#endif + } + + // This is the method we expect to use going forward + static inline CUTLASS_HOST + Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void const* kernel, + void** kernel_params, + bool launch_with_pdl = false) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, cluster_dims, + block_dims, smem_size, cuda_stream, launch_with_pdl); + + auto launch_grid_dims = cluster_launch_config.gridDim(); + if (check_cluster_dims(launch_grid_dims, cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); + return Status::kInvalid; + } + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching GridDims = " + "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " + "And ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + cutlass::arch::synclog_setup(); + cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! 
Aborting cluster launch."); + return Status::kInvalid; +#endif + } + + + // This is the method we expect to use going forward + // Launch a preferred cluster grid + static inline CUTLASS_HOST + Status launch_with_fallback_cluster( + dim3 const grid_dims, + dim3 const preferred_cluster_dims, + dim3 const fallback_cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void const* kernel, + void** kernel_params, + bool launch_with_pdl = false) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + LaunchConfig cluster_launch_config = make_cluster_launch_config(grid_dims, preferred_cluster_dims, + block_dims, smem_size, cuda_stream, launch_with_pdl, fallback_cluster_dims); + + auto launch_grid_dims = cluster_launch_config.gridDim(); + if (check_cluster_dims(launch_grid_dims, preferred_cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". Aborting."); + return Status::kInvalid; + } + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching \n\tGridDims = " + "(" << launch_grid_dims.x << ", " << launch_grid_dims.y << ", " << launch_grid_dims.z << "), " + "\n\tPreferred ClusterDims = " + "(" << preferred_cluster_dims.x << ", " << preferred_cluster_dims.y << ", " << preferred_cluster_dims.z << ")," + "\n\tFallback ClusterDims = " + "(" << fallback_cluster_dims.x << ", " << fallback_cluster_dims.y << ", " << fallback_cluster_dims.z << ")\n"); + + cutlass::arch::synclog_setup(); + cudaError_t status = cudaLaunchKernelExC(&cluster_launch_config.launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! 
Aborting cluster launch."); + return Status::kInvalid; +#endif + } + + +}; + +namespace detail { + +template +void* checked_addressof(Arg&& arg) { + static_assert(! std::is_rvalue_reference_v || ! std::is_const_v, "You cannot take the address of a const rvalue reference (const T&&)."); + // We use std::addressof to ensure we get the address, + // in case the type has an overloaded operator&. + // Note that this precludes `const T&&` references. + return const_cast(reinterpret_cast(std::addressof(arg))); +} + +} // namespace detail + +//! Parameters for launch_on_cluster (see below). +struct ClusterLaunchParams { + //! Grid dimensions + dim3 grid_dims{1, 1, 1}; + + //! Block dimensions + dim3 block_dims{1, 1, 1}; + + //! Cluster dimensions + dim3 cluster_dims{1, 1, 1}; + + //! Number of bytes required for the kernel's shared memory. + int smem_size_in_bytes = 0; + + //! CUDA stream on which to launch the kernel. + cudaStream_t cuda_stream = nullptr; +}; + +/// @brief Launch the kernel on the stream using cluster launch. +/// +/// @param params Cluster launch parameters (see above). +/// @param kernel_ptr Pointer to the kernel function (see example). +/// @param args Zero or more arguments to pass to the kernel. +/// +/// @tparam Args Types of the arguments passed to the kernel. +/// Don't specify this/these template argument(s) explicitly. +/// +/// @return Status::Success on success, else an error code. +/// +/// @code +/// template +/// __global__ void kernel(A a, B b, C c); +/// +/// X x = get_x(); +/// Y y = get_y(); +/// Z z = get_z(); +/// +/// void const* kernel_ptr = +/// const_cast(reinterpret_cast( +/// &kernel)); +/// auto status = launch_kernel_on_cluster( +/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, +/// kernel_ptr, x, y, z); +/// @endcode +template +CUTLASS_HOST cutlass::Status +launch_kernel_on_cluster(const ClusterLaunchParams& params, + void const* kernel_ptr, + Args&& ... 
args) +{ + // Unfortunately, we find ourselves needing to pass in + // the parameters as an array of raw pointers. + if constexpr (sizeof...(Args) == 0) { + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, nullptr); + } + else { + void* kernel_params[sizeof...(Args)] = { + detail::checked_addressof(std::forward(args))... + }; + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, + kernel_params); + } +} + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h new file mode 100644 index 0000000000000000000000000000000000000000..0287850bc6febe16a90695c82fabee566cdf9a82 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/complex.h @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#endif +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" +#include "cutlass/real.h" + +#include "cutlass/numeric_types.h" + +#include "cutlass/fast_math.h" + +#if !defined(__CUDACC_RTC__) +#include +#endif + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Enumeraed type describing a transformation on a complex value. 
+enum class ComplexTransform { + kNone, + kConjugate +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines ComplexTransform inversions +template +struct InvertComplexTransform; + +/// Invert ComplexTransform from kNone to kConjugate +template <> +struct InvertComplexTransform { + static ComplexTransform const transform = ComplexTransform::kConjugate; +}; + +/// Invert ComplexTransform from kConjugate to kNone +template <> +struct InvertComplexTransform { + static ComplexTransform const transform = ComplexTransform::kNone; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Accessors for CUDA complex types +// + +#if !defined(__CUDACC_RTC__) +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +float const &real(cuFloatComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +float &real(cuFloatComplex &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +double const &real(cuDoubleComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +CUTLASS_HOST_DEVICE +double &real(cuDoubleComplex &z) { return z.x; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +float const &imag(cuFloatComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +float &imag(cuFloatComplex &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +double const &imag(cuDoubleComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +CUTLASS_HOST_DEVICE +double &imag(cuDoubleComplex &z) { return z.y; } + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuFloatComplex 
+conj(cuFloatComplex const& z) { + return make_cuFloatComplex(z.x, -z.y); +} + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuDoubleComplex +conj(cuDoubleComplex const& z) { + return make_cuDoubleComplex(z.x, -z.y); +} +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Class for representing and manipulating complex numbers with conversions from built-in CUDA +/// complex types. + +template +class complex +{ + public: + /// Type alias for scalar type + using value_type = T; + + private: + // + // Data members + // + + /// Real part + T _real; + + /// Imaginary part + T _imag; + + public: + +// +// Methods +// + + /// Default constructor + complex() = default; + + /// Constructor + CUTLASS_HOST_DEVICE + complex(T r) : _real(r), _imag(T(0)) {} + + /// Constructor + CUTLASS_HOST_DEVICE + complex(T r, T i) : _real(r), _imag(i) {} + + /// Constructor + template + CUTLASS_HOST_DEVICE + complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} + + + #if !defined(__CUDACC_RTC__) + /// Conversion from cuFloatComplex + CUTLASS_HOST_DEVICE + complex(cuFloatComplex const &z) : _real(static_cast(cuCrealf(z))), _imag(static_cast(cuCimagf(z))) {} + + /// Conversion from cuDoubleComplex + CUTLASS_HOST_DEVICE + complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} + #endif + + /// Equality operator + CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { + return this->real() == rhs.real() && this->imag() == rhs.imag(); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE bool operator!=(complex const &rhs) const { + return !(*this == rhs); + } + + /// Addition + template + CUTLASS_HOST_DEVICE complex operator+(complex const &rhs) const { + return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); + } + + /// Reduction into memory address. Components may update out of order. 
+ template + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + cutlass::atomic_add reduce; + reduce(&ptr->_real, _real); + reduce(&ptr->_imag, _imag); + } + + /// Reduction into memory address. Components may update out of order. (Half specialization) + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + half2 *h2_ptr = reinterpret_cast(ptr); + half2 h2_data = reinterpret_cast(*this); + cutlass::atomic_add reduce; + reduce(h2_ptr, h2_data); + } + + /// Subtraction + template + CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { + return complex(this->real() - rhs.real(), this->imag() - rhs.imag()); + } + + /// Multiplication + template + CUTLASS_HOST_DEVICE complex operator*(complex const &rhs) const { + return complex(this->real() * rhs.real() - this->imag() * rhs.imag(), + this->real() * rhs.imag() + this->imag() * rhs.real()); + } + + /// Scalar Multiplication + template + CUTLASS_HOST_DEVICE complex operator*(A const &s) const { + return complex(this->real() * s, this->imag() * s); + } + + /// Division + template + CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { + T d = T(rhs.real() * rhs.real() + rhs.imag() * rhs.imag()); + + return complex( + (real() * rhs.real() + imag() * rhs.imag()) / d, + (imag() * rhs.real() - real() * rhs.imag()) / d + ); + } + + /// Scalar Division + template + CUTLASS_HOST_DEVICE complex operator/(A const &s) const { + return complex(this->real() / s, this->imag() / s); + } + + /// Addition + template + CUTLASS_HOST_DEVICE complex &operator+=(complex const &rhs) { + *this = *this + rhs; + return *this; + } + + /// Subtraction + template + CUTLASS_HOST_DEVICE complex &operator-=(complex const &rhs) { + *this = *this - rhs; + return *this; + } + + /// Multiplication + template + CUTLASS_HOST_DEVICE complex &operator*=(complex const &rhs) { + *this = *this * rhs; + 
return *this; + } + + /// Scalar multiplication + template + CUTLASS_HOST_DEVICE complex &operator*=(A s) { + *this = *this * s; + return *this; + } + + /// Division + template + CUTLASS_HOST_DEVICE complex &operator/=(complex const &rhs) { + *this = *this / rhs; + return *this; + } + + /// Accesses the real part of the complex number + CUTLASS_HOST_DEVICE + T const &real() const { return _real; } + + /// Accesses the real part of the complex number + CUTLASS_HOST_DEVICE + T &real() { return _real; } + + /// Accesses the imaginary part of the complex number + CUTLASS_HOST_DEVICE + T const &imag() const { return _imag; } + + /// Accesses the imaginary part of the complex number + CUTLASS_HOST_DEVICE + T &imag() { return _imag; } + + /// Set the real part of the complex number + CUTLASS_HOST_DEVICE + void real(T real) { _real = real; } + + /// Set the imaginary part of the complex number + CUTLASS_HOST_DEVICE + void imag(T imag) { _imag = imag; } + + #if !defined(__CUDACC_RTC__) + /// Converts to cuFloatComplex + CUTLASS_HOST_DEVICE + explicit operator cuFloatComplex() const { return make_cuFloatComplex(float(real()), float(imag())); } + + /// Converts to cuDoubleComplex + CUTLASS_HOST_DEVICE + explicit operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } + #endif +}; + +// Complex conjugate +template +CUTLASS_HOST_DEVICE complex conj(complex const& z) { + return {z.real(), -z.imag()}; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Accessors for complex template +// + +// Nonmember real and imag need to work for non-complex numbers too. +// That means cutlass::complex, std::complex, cuda::std::complex, and +// any user-defined complex number type that looks like std::complex. +// It's reasonable to assume that a "complex number type" has +// zero-argument real() and imag() member functions returning +// non-void. 
While cuFloatComplex and cuDoubleComplex lack those +// member functions, one-argument nonmember real and imag overloads +// for those types are defined above. + +namespace detail { + +template +struct has_zero_argument_real_member_function : + cutlass::platform::false_type +{}; + +template +struct has_zero_argument_real_member_function().real()) + > + > +> : cutlass::platform::true_type +{}; + +template +constexpr bool has_zero_argument_real_member_function_v = + has_zero_argument_real_member_function::value; + +template +struct has_zero_argument_imag_member_function : + cutlass::platform::false_type +{}; + +template +struct has_zero_argument_imag_member_function().imag()) + > + > +> : cutlass::platform::true_type +{}; + +template +constexpr bool has_zero_argument_imag_member_function_v = + has_zero_argument_imag_member_function::value; + +} // namespace detail + +template +CUTLASS_HOST_DEVICE auto real(T z) { + if constexpr (detail::has_zero_argument_real_member_function_v) { + return z.real(); + } else { + return z; + } +} + +template +CUTLASS_HOST_DEVICE auto imag(T z) { + if constexpr (detail::has_zero_argument_imag_member_function_v) { + return z.imag(); + } else { + // Imaginary part of a non-complex input has the same type as the + // input, and its value is zero. CUTLASS assumes in this case + // that value-initializing T is well-formed and results in zero. + return T{}; + } +} + +// +// Output operators +// + +#if !defined(__CUDACC_RTC__) +template +std::ostream &operator<<(std::ostream &out, complex const &z) { + T _r = real(z); + T _i = imag(z); + + if (bool(_i)) { + return out << _r << "+i" << _i; + } + return out << _r; +} +#endif + +// +// Non-member operators defined for complex types +// + + +// +// Non-member functions defined for complex numbers +// + +// abs returns the magnitude of the complex number. 
+ +CUTLASS_HOST_DEVICE float abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +CUTLASS_HOST_DEVICE double abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +// In theory, it would make sense to add a complex +// specialization of abs here, since hypot works for long double too. +// In practice, long double doesn't have a portable number of bits or +// behavior, so users who care about higher-precision floating-point +// computation should probably insist on an actual FP128 type. + +template +CUTLASS_HOST_DEVICE T abs(complex const &z) { + // cutlass::complex permits all kinds of T, including types that + // don't have NaN. For a generic floating-point type with Inf + // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it + // would handle issues like avoiding unwarranted overflow if + // z.real() or z.imag() is slightly bigger than the square root of + // the max finite number. That could be a future improvement; for + // now, the code just uses the naive algorithm. + // + // Use the "swap two-step" idiom so that argument-dependent lookup + // can find any CUTLASS-specific overloads. 
+ using cutlass::sqrt; + return sqrt(z.real() * z.real() + z.imag() * z.imag()); +} + +/// Returns the magnitude of the complex number +template +CUTLASS_HOST_DEVICE T arg(complex const &z) { + return atan2(imag(z), real(z)); +} + +/// Returns the squared magnitude of a real number +template +CUTLASS_HOST_DEVICE T norm(T const &z) { + return z * z; +} + +/// Returns the squared magnitude of a real number +template <> +CUTLASS_HOST_DEVICE int8_t norm(int8_t const &z) { + return static_cast(z * z); +} + +/// Returns the squared magnitude of a complex number +template +CUTLASS_HOST_DEVICE double norm(complex const &z) { + return real(z) * real(z) + imag(z) * imag(z); +} + +/// Norm-accumulate calculation +template +CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { + return accumulator + static_cast(x) * static_cast(x); +} + +/// Norm accumulate specialized for complex types +template +CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { + return accumulator + static_cast(real(z)) * static_cast(real(z)) + + static_cast(imag(z)) * static_cast(imag(z)); +} + +namespace detail { + +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) { + return conj(z); +} + +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) { + return z; +} + +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z) { + constexpr bool use_unqualified_conj = + ! cutlass::platform::is_arithmetic_v && + ! detail::has_cutlass_conj_v && + detail::has_unqualified_conj_v; + return conj_impl(z, cutlass::platform::bool_constant{}); +} + +} // namespace detail + +// Return the complex conjugate of the input. +// +// This MUST be a function and not a function object, because it may +// be common practice for downstream types to define specifically +// cutlass::conj overloads, instead of overloads in their namespace. 
+// +// As a result of this being a function and not a function object, +// CUTLASS code needs to declare "using cutlass::conj;" in scope and +// then call this function unqualified, just like std::swap. +// +// If an overload already exists for cutlass::conj(T), that overload +// will be called instead of this one. Otherwise: +// +// 1. for arithmetic types, return z; +// +// 2. for types where (namespace-unqualified) conj(z) is well formed +// and cutlass::conj(z) is NOT well formed, return conj(z); and, +// +// 3. for everything else, return z. +// +// Regarding (1), the C++ Standard Library makes std::conj always +// return std::complex, even for (noncomplex) arithmetic types. +// cutlass::conj(T t) needs to return type T. This follows the +// convention of linear algebra software like the BLAS, where +// "conjugate transpose" means the same thing as "transpose" for a +// matrix of noncomplex numbers. +// +// Case (2) covers std::complex, cuda::std::complex, and non-Standard +// (including user-defined) complex number types (for which "conj(z)" +// is findable via argument-dependent lookup, but does not live in the +// cutlass namespace). It excludes cutlass::conj(z) in order to +// prevent infinite recursion. +// +// Case (3) covers non-Standard non-complex number types. +template +CUTLASS_HOST_DEVICE T conj(T const& z) { + return detail::conj_impl(z); +} + +/// Projects the complex number z onto the Riemann sphere +template +CUTLASS_HOST_DEVICE complex proj(complex const &z) { + T d = real(z) * real(z) + imag(z) * imag(z) + T(1); + return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); +} + +/// Returns a complex number with magnitude r and phase theta +template +CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { + return complex(r * cos(theta), r * sin(theta)); +} + +/// Computes the complex exponential of z. 
+template +CUTLASS_HOST_DEVICE complex exp(complex const &z) { + return complex(fast_exp(real(z)) * fast_cos(imag(z)), fast_exp(real(z)) * fast_sin(imag(z))); +} + +/// Computes the log of z +template +CUTLASS_HOST_DEVICE complex log(complex const &z) { + return complex(log(abs(z)), arg(z)); +} + +/// Computes the log base 10 of z +template +CUTLASS_HOST_DEVICE complex log10(complex const &z) { + return log(z) / T(log(T(10))); +} + +/// Computes the square root of complex number z +template +CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { + return sqrt(T(2)) / T(2) * + complex(sqrt(sqrt(norm(z)) + real(z)), + (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); +} + +/// Computes the cosine of complex z. +template +CUTLASS_HOST_DEVICE complex cos(complex const &z) { + return (exp(z) + exp(-z)) / T(2); +} + +/// Computes the sin of complex z. +template +CUTLASS_HOST_DEVICE complex sin(complex const &z) { + return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); +} + +/// Comparison +template +CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { + return true; +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for complex-valued type. 
+template +struct RealType< complex > +{ + using Type = T; + + /// Number of elements + static int const kExtent = 2; + + CUTLASS_HOST_DEVICE + static complex from_real(double x) { + return complex(static_cast(x)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(half_t(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(float(r)); +} + +template <> +CUTLASS_HOST_DEVICE +cutlass::complex from_real >(double r) { + return cutlass::complex(r); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct is_complex { + static bool const value = false; +}; + +template +struct is_complex> { + static bool const value = true; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()); + Output y_i = Output(lhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b.real(); + real += -a.imag() * b.imag(); + imag += a.real() * b.imag(); + imag += a.imag () * b.real(); + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, T, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex 
const &a, + T const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b; + imag += a.imag () * b; + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + T const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a * b.real(); + imag += a * b.imag(); + + return complex{ + real, + imag + }; + } +}; + +/// Conjugate +template +struct conjugate> { + CUTLASS_HOST_DEVICE + complex operator()(complex const &a) const { + // Invoke the complex overload specifically, rather than + // wasting the compiler's effort on overload resolution. + return cutlass::conj(a); + } +}; + +#if ! defined(__CUDACC_RTC__) +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuFloatComplex operator()(cuFloatComplex const& z) const { + return make_cuFloatComplex(z.x, -z.y); + } +}; + +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuDoubleComplex operator()(cuDoubleComplex const& z) const { + return make_cuDoubleComplex(z.x, -z.y); + } +}; +#endif + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs, complex rhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()) - Output(rhs.real()); + Output y_i = Output(lhs.imag()) - Output(rhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Reduces value into the data pointed to by ptr (complex specialization) +template +struct atomic_add> { + CUTLASS_DEVICE + void operator()(complex *ptr, const complex &data) + { + data.red(ptr); + } +}; + + +////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +////////////////////////////////////////////////////////////////////////////////////////////////// diff 
--git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h new file mode 100644 index 0000000000000000000000000000000000000000..f5df01726b3f4dbc88bf2fd6f15092cff2b55fac --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/constants.h @@ -0,0 +1,1239 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief Boost-style constant definitions for floating-point types. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/complex.h" + +/////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace constants { + +/////////////////////////////////////////////////////////////////////////////////// + +// +// Primary templates +// + +/// Returns 1, the multiplicative identity element +template CUTLASS_HOST_DEVICE T one(); + +/// Returns 0, the additive identity element +template CUTLASS_HOST_DEVICE T zero(); + +/// Returns 2 +template CUTLASS_HOST_DEVICE T two(); + +/// Returns pi, approximately 3.141 +template CUTLASS_HOST_DEVICE T pi(); + +/// Returns 2 * pi +template CUTLASS_HOST_DEVICE T two_pi(); + +/// Returns pi / 2 +template CUTLASS_HOST_DEVICE T half_pi(); + +/// Returns sqrt(pi) +template CUTLASS_HOST_DEVICE T root_pi(); + +/// Returns sqrt(pi / 2) +template CUTLASS_HOST_DEVICE T root_half_pi(); + +/// Returns sqrt(2 * pi) +template CUTLASS_HOST_DEVICE T root_two_pi(); + +/// Returns sqrt(ln(4)) +template CUTLASS_HOST_DEVICE T root_ln_four(); + +/// Returns e, approximately 2.718... 
+template CUTLASS_HOST_DEVICE T e(); + +/// Returns (1/2) +template CUTLASS_HOST_DEVICE T half(); + +/// Returns sqrt(2), approximately 1.414... +template CUTLASS_HOST_DEVICE T root_two(); + +/// Returns sqrt(2)/2, approximately 0.707... +template CUTLASS_HOST_DEVICE T half_root_two(); + +/// Returns ln(2), approximately 0.693... +template CUTLASS_HOST_DEVICE T ln_two(); + +/// Returns ln(ln(2)), approximately -0.3665... +template CUTLASS_HOST_DEVICE T ln_ln_two(); + +/// Returns 1/3, approximately 0.333... +template CUTLASS_HOST_DEVICE T third(); + +/// Returns 2/3, approximately 0.666... +template CUTLASS_HOST_DEVICE T twothirds(); + +/// Returns pi - 3, approximately 0.1416... +template CUTLASS_HOST_DEVICE T pi_minus_three(); + +/// Returns 4 - pi, approximately 0.858... +template CUTLASS_HOST_DEVICE T four_minus_pi(); + + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for double + +/// Returns 1, the multiplicative identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double one() { + uint64_t bits = 0x3ff0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), double()); +} + +/// Returns 0, the additive identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double zero() { + uint64_t bits = 0x0ull; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), double()); +} + +/// Returns 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double two() { + uint64_t bits = 0x4000000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return 
complex(two(), double()); +} + +/// Returns pi, approximately 3.141 (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi() { + uint64_t bits = 0x400921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), double()); +} + +/// Returns 2 * pi (specialization for double) +template <> CUTLASS_HOST_DEVICE double two_pi() { + uint64_t bits = 0x401921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), double()); +} + +/// Returns pi / 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_pi() { + uint64_t bits = 0x3ff921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), double()); +} + +/// Returns sqrt(pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_pi() { + uint64_t bits = 0x3ffc5bf891b4ef6aull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), double()); +} + +/// Returns sqrt(pi / 2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_half_pi() { + uint64_t bits = 0x3ff40d931ff62705ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), double()); +} + +/// Returns sqrt(2 * pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two_pi() { + uint64_t bits = 0x40040d931ff62705ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for 
complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), double()); +} + +/// Returns sqrt(ln(4)) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_ln_four() { + uint64_t bits = 0x3ff2d6abe44afc43ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), double()); +} + +/// Returns e, approximately 2.718... (specialization for double) +template <> CUTLASS_HOST_DEVICE double e() { + uint64_t bits = 0x4005bf0a8b145769ull; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), double()); +} + +/// Returns (1/2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double half() { + uint64_t bits = 0x3fe0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), double()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two() { + uint64_t bits = 0x3ff6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), double()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_root_two() { + uint64_t bits = 0x3fe6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), double()); +} + +/// Returns ln(2), approximately 0.693... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_two() { + uint64_t bits = 0x3fe62e42fefa39efull; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), double()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_ln_two() { + uint64_t bits = 0xbfd774f29bdd6b9full; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), double()); +} + +/// Returns 1/3, approximately 0.333... (specialization for double) +template <> CUTLASS_HOST_DEVICE double third() { + uint64_t bits = 0x3fd5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), double()); +} + +/// Returns 2/3, approximately 0.666... (specialization for double) +template <> CUTLASS_HOST_DEVICE double twothirds() { + uint64_t bits = 0x3fe5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), double()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi_minus_three() { + uint64_t bits = 0x3fc21fb54442d180ull; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), double()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for double) +template <> CUTLASS_HOST_DEVICE double four_minus_pi() { + uint64_t bits = 0x3feb7812aeef4ba0ull; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), double()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for float + +/// Returns 1, the multiplicative identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float one() { + uint32_t bits = 0x3f800000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), float()); +} + +/// Returns 0, the additive identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float zero() { + uint32_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), float()); +} + +/// Returns 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float two() { + uint32_t bits = 0x40000000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), float()); +} + +/// Returns pi, approximately 3.141 (specialization for float) +template <> CUTLASS_HOST_DEVICE float pi() { + uint32_t bits = 0x40490fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< 
complex >() { + return complex(pi(), float()); +} + +/// Returns 2 * pi (specialization for float) +template <> CUTLASS_HOST_DEVICE float two_pi() { + uint32_t bits = 0x40c90fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), float()); +} + +/// Returns pi / 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_pi() { + uint32_t bits = 0x3fc90fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), float()); +} + +/// Returns sqrt(pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_pi() { + uint32_t bits = 0x3fe2dfc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), float()); +} + +/// Returns sqrt(pi / 2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_half_pi() { + uint32_t bits = 0x3fa06c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), float()); +} + +/// Returns sqrt(2 * pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two_pi() { + uint32_t bits = 0x40206c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), float()); +} + +/// Returns sqrt(ln(4)) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_ln_four() { + uint32_t bits = 0x3f96b55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex 
root_ln_four< complex >() { + return complex(root_ln_four(), float()); +} + +/// Returns e, approximately 2.718... (specialization for float) +template <> CUTLASS_HOST_DEVICE float e() { + uint32_t bits = 0x402df854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), float()); +} + +/// Returns (1/2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float half() { + uint32_t bits = 0x3f000000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), float()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two() { + uint32_t bits = 0x3fb504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), float()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_root_two() { + uint32_t bits = 0x3f3504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), float()); +} + +/// Returns ln(2), approximately 0.693... (specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_two() { + uint32_t bits = 0x3f317218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), float()); +} + +/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_ln_two() { + uint32_t bits = 0xbebba795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), float()); +} + +/// Returns 1/3, approximately 0.333... (specialization for float) +template <> CUTLASS_HOST_DEVICE float third() { + uint32_t bits = 0x3eaaaaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), float()); +} + +/// Returns 2/3, approximately 0.666... (specialization for float) +template <> CUTLASS_HOST_DEVICE float twothirds() { + uint32_t bits = 0x3f2aaaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), float()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for float) +template <> CUTLASS_HOST_DEVICE float pi_minus_three() { + uint32_t bits = 0x3e10fdaau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), float()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for float) +template <> CUTLASS_HOST_DEVICE float four_minus_pi() { + uint32_t bits = 0x3f5bc095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), float()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for tfloat32_t + +/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t one() { + uint32_t bits = 0x3f801000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), tfloat32_t()); +} + +/// Returns 0, the additive identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { + uint32_t bits = 0x1000u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), tfloat32_t()); +} + +/// Returns 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two() { + uint32_t bits = 0x40001000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), tfloat32_t()); +} + +/// Returns pi, approximately 3.141 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { + uint32_t bits = 0x40491fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), tfloat32_t()); +} + +/// Returns 2 * pi (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { + uint32_t bits = 0x40c91fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() 
{ + return complex(two_pi(), tfloat32_t()); +} + +/// Returns pi / 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { + uint32_t bits = 0x3fc91fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi() { + uint32_t bits = 0x3fe2efc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi / 2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { + uint32_t bits = 0x3fa07c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), tfloat32_t()); +} + +/// Returns sqrt(2 * pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { + uint32_t bits = 0x40207c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), tfloat32_t()); +} + +/// Returns sqrt(ln(4)) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { + uint32_t bits = 0x3f96c55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), tfloat32_t()); +} + +/// Returns e, approximately 2.718... 
(specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t e() { + uint32_t bits = 0x402e0854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), tfloat32_t()); +} + +/// Returns (1/2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half() { + uint32_t bits = 0x3f001000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), tfloat32_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { + uint32_t bits = 0x3fb514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), tfloat32_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { + uint32_t bits = 0x3f3514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), tfloat32_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { + uint32_t bits = 0x3f318218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), tfloat32_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { + uint32_t bits = 0xbebbb795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), tfloat32_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t third() { + uint32_t bits = 0x3eaabaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), tfloat32_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { + uint32_t bits = 0x3f2abaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), tfloat32_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { + uint32_t bits = 0x3e110daau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), tfloat32_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { + uint32_t bits = 0x3f5bd095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), tfloat32_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for half_t + +/// Returns 1, the multiplicative identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t one() { + uint16_t bits = 0x3c00u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), half_t()); +} + +/// Returns 0, the additive identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), half_t()); +} + +/// Returns 2 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), half_t()); +} + +/// Returns pi, approximately 3.141 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi() { + uint16_t bits = 0x4248u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), half_t()); +} + +/// Returns 2 * pi (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two_pi() { + uint16_t bits = 0x4648u; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), half_t()); +} + +/// Returns pi / 2 
(specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_pi() { + uint16_t bits = 0x3e48u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), half_t()); +} + +/// Returns sqrt(pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_pi() { + uint16_t bits = 0x3f17u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), half_t()); +} + +/// Returns sqrt(pi / 2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { + uint16_t bits = 0x3d03u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), half_t()); +} + +/// Returns sqrt(2 * pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { + uint16_t bits = 0x4103u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), half_t()); +} + +/// Returns sqrt(ln(4)) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { + uint16_t bits = 0x3cb6u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), half_t()); +} + +/// Returns e, approximately 2.718... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t e() { + uint16_t bits = 0x4170u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), half_t()); +} + +/// Returns (1/2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half() { + uint16_t bits = 0x3800u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), half_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two() { + uint16_t bits = 0x3da8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), half_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_root_two() { + uint16_t bits = 0x39a8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), half_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_two() { + uint16_t bits = 0x398cu; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), half_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { + uint16_t bits = 0xb5ddu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), half_t()); +} + +/// Returns 1/3, approximately 0.333... 
(specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t third() { + uint16_t bits = 0x3555u; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), half_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t twothirds() { + uint16_t bits = 0x3955u; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), half_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { + uint16_t bits = 0x3088u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), half_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { + uint16_t bits = 0x3adeu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), half_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for bfloat16_t + +/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t one() { + uint16_t bits = 0x3f80u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), bfloat16_t()); +} + +/// Returns 0, the additive identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), bfloat16_t()); +} + +/// Returns 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), bfloat16_t()); +} + +/// Returns pi, approximately 3.141 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { + uint16_t bits = 0x4049u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), bfloat16_t()); +} + +/// Returns 2 * pi (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { + uint16_t bits = 0x40c9u; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return 
complex(two_pi(), bfloat16_t()); +} + +/// Returns pi / 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { + uint16_t bits = 0x3fc9u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { + uint16_t bits = 0x3fe3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi / 2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { + uint16_t bits = 0x3fa0u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), bfloat16_t()); +} + +/// Returns sqrt(2 * pi) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { + uint16_t bits = 0x4020u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), bfloat16_t()); +} + +/// Returns sqrt(ln(4)) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { + uint16_t bits = 0x3f97u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), bfloat16_t()); +} + +/// Returns e, approximately 2.718... 
(specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t e() { + uint16_t bits = 0x402eu; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), bfloat16_t()); +} + +/// Returns (1/2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half() { + uint16_t bits = 0x3f00u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), bfloat16_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { + uint16_t bits = 0x3fb5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), bfloat16_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { + uint16_t bits = 0x3f35u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), bfloat16_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { + uint16_t bits = 0x3f31u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), bfloat16_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { + uint16_t bits = 0xbebcu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), bfloat16_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t third() { + uint16_t bits = 0x3eabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), bfloat16_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { + uint16_t bits = 0x3f2bu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), bfloat16_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { + uint16_t bits = 0x3e11u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), bfloat16_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { + uint16_t bits = 0x3f5cu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), bfloat16_t()); +} +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace constants +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e032f9599a5e76eb1e8dd6b5279ae9a42ce9c9b4 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_builder.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/collective_conv.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { + static constexpr int value = num_stages; + + StageCount() = default; + explicit StageCount(cute::Int) {} +}; + +template +struct StageCountAutoCarveout { + static constexpr int bytes = carveout_bytes; + + StageCountAutoCarveout() = default; + explicit StageCountAutoCarveout(cute::Int) {} +}; + +// Used to automatically let the builder pick the kernel schedule. 
+// Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp +struct KernelScheduleAuto {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + conv::Operator, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_gmma_builder.inl" +#include "builders/sm100_umma_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0bb596fe02b36d350a1d1065ff5001794eba170 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/collective_conv.hpp @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/detail.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class ElementB, + class TiledMma, + class TileTraitsA, + class TileTraitsB +> +struct CollectiveConv { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp" +#include "sm100_implicit_gemm_umma_warpspecialized.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af541a940f787528d213f068915ce0aa5997a82f --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/detail.hpp @@ -0,0 +1,271 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_A() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes ((w,n), C) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,d,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (k, nq/npq/nzpq) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1 || + DispatchPolicy::NumSpatialDimensions == 2 || + DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes ((q,n), K) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((q,p,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + 
// Maps to modes ((q,p,z,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + +// Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_B() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes (k, (C,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, int64_t>>{}; + } + // Maps to modes (k, (C,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, int64_t, int64_t>>{}; + } + // Maps to modes (k, (C,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t, int64_t, int64_t>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (C, (w,n)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,n)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,d,n)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes (C, (k,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, cute::Stride>{}; + } + // 
Maps to modes (C, (k,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, cute::Stride>{}; + } + // Maps to modes (C, (k,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + + +template +constexpr auto +sm100_dispatch_policy_to_stride_A() { + return sm90_dispatch_policy_to_stride_A(); +} + +template +constexpr auto +sm100_dispatch_policy_to_stride_B() { + return sm90_dispatch_policy_to_stride_B(); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Compute the lower/near corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + +// Computes the upper/far corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_upper_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array upper{}; + if constexpr (ConvOp == conv::Operator::kFprop) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + 
(problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + (problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1]; + }); + } + return upper; +} + +// Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_srt(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = 0; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + +template struct is_im2col_load { static constexpr bool value = false; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3c541c325004eb8488ca7353eed9a43fa4ae280 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -0,0 +1,917 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/cluster.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/trace.h" + +#if (! 
defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) +# include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + conv::Operator ConvOp, + int Stages, + int NumSpatialDims, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) + class ElementA_, + class ElementB_, + class TiledMma_, + class TileTraitsA_, + class TileTraitsB_> +struct CollectiveConv< + MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShapeMNKL_, + ElementA_, + ElementB_, + TiledMma_, + TileTraitsA_, + TileTraitsB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedImplicitGemm< + ConvOp, + Stages, + NumSpatialDims, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) + using ElementA = ElementA_; + using ElementB = ElementB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; + using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; + using SmemLayoutAtomA = typename TileTraitsA_::SmemLayoutAtom; + using SmemLayoutAtomB = typename TileTraitsB_::SmemLayoutAtom; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int NumSpatialDimensions = 
DispatchPolicy::NumSpatialDimensions; + static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; + // deducde the kernel facing stride tuple types based on the dispatch policy (spatial dim, algo, etc.) + using StrideA = decltype(detail::sm100_dispatch_policy_to_stride_A()); + using StrideB = decltype(detail::sm100_dispatch_policy_to_stride_B()); + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using TmaInternalElementA = cute::conditional_t>>; + using TmaInternalElementB = cute::conditional_t>>; + + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using ProblemShape = ConvProblemShape; + + CUTE_STATIC_ASSERT_V(evenly_divides(shape<0>(TileShape{}), tile_size<0>(TiledMma{})), "TileShape_M should be evenly divided by TiledMma_M"); + CUTE_STATIC_ASSERT_V(evenly_divides(shape<1>(TileShape{}), tile_size<1>(TiledMma{})) || (ConvOp == conv::Operator::kWgrad), "TileShape_N should be evenly divided by TiledMma_N"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * 
size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + Step<_2,_1,_3>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + static constexpr bool is_im2col_A = detail::is_im2col_load::value; + static constexpr bool is_im2col_B = detail::is_im2col_load::value; + static constexpr bool is_strided_dgrad = ConvOp == conv::Operator::kDgrad && not is_im2col_A && not is_im2col_B; + + static constexpr int TileShapeMNKLRank = rank(TileShapeMNKL_{}); + // If rank > 3, TileL exists and it is GroupsPerTile. The kernel is grouped conv now. 
+ static constexpr bool is_grouped_wgrad = ConvOp == conv::Operator::kWgrad && TileShapeMNKLRank > 3; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof(ElementA))) + + size(AtomThrShapeMNK{}) * (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + ElementB const* ptr_B{nullptr}; + }; + +private: + + // Note that for fprop and non-strided dgrad kernel, the tma load mode is im2col for tensor A and tiled for + // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor + // B since operand A, B is swapped. 
+ // For strided dgrad A and B are both tma tiled and not im2col + + template + static constexpr auto + get_tma_load_a_instance( + TensorA const& tensor_a, + ProblemShape const& problem_shape, + ClusterShapeVMNK const& cluster_shape_vmnk) { + + if constexpr (is_im2col_A) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + // gbasis strides for dgrad kernel need to be negated + cute::array stride_srt{}; + for (int i = 0; i < NumSpatialDimensions; ++i) { + stride_srt[i] = ConvOp == conv::Operator::kDgrad ? + -problem_shape.dilation[NumSpatialDimensions-1-i] : + problem_shape.dilation[NumSpatialDimensions-1-i]; + } + + return make_im2col_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk, + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + shape(stride_srt)); + } + // TMA tiled mode for tensor A in wgrad and strided dgrad + else { + return make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk); + } + } + + template + static constexpr auto + get_tma_load_b_instance( + TensorB const& tensor_b, + ProblemShape const& problem_shape, + ClusterShapeVMNK const& cluster_shape_vmnk) { + + if constexpr (is_im2col_B) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + 
return make_im2col_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk, + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + cute::reverse(shape(problem_shape.dilation))); + } + else { + return make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_shape_vmnk); + } + } + +public: + + // Performs im2col transformations on the input of type ConvProblemShape + static constexpr auto + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + if constexpr (is_im2col_A || is_im2col_B) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + } + + // Device-side kernel params + // + // Arguments has the untransformed problem shape from the user. + // Params will have the transformed problem shape. + struct Params { + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); + + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + // Assumption: StrideA is congruent with Problem_MK + // Select TMA load type according to convolution operator. 
+ using TensorShapeA = cute::conditional_t; + + using TensorShapeB = cute::conditional_t; + + using TMA_A = decltype(get_tma_load_a_instance( + make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(TensorShapeA{}, StrideA{})), + ConvProblemShape{}, + ClusterLayout_VMNK{})); + + using TMA_B = decltype(get_tma_load_b_instance( + make_tensor( + make_gmem_ptr(recast_ptr(nullptr)), + make_layout(TensorShapeB{}, StrideB{})), + ConvProblemShape{}, + ClusterLayout_VMNK{})); + + // Members + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + // + // Constructor + // + CUTLASS_DEVICE + CollectiveConv(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + // + // Methods + // + + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple + // tma desc creation depends on the original untransformed domain. + + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. 
+ auto shape_B_orig = problem_shape.get_shape_B(); + + // Fill inferred cute strides from flat stride arrays + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); + Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk); + auto tma_load_a_fallback = get_tma_load_a_instance(tensor_a, problem_shape, cluster_layout_vmnk_fallback); + auto tma_load_b_fallback = get_tma_load_b_instance(tensor_b, problem_shape, cluster_layout_vmnk_fallback); + + static_assert(size(typename decltype(tma_load_a)::ThrID{}) == size(AtomThrShapeMNK{})); + static_assert(size(typename decltype(tma_load_b)::ThrID{}) == size(AtomThrShapeMNK{})); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + Arguments const& args) { + // Activation and Filter channel mode extents much match + bool implementable = true; + // channel 
mode is major + { + const bool check = problem_shape.stride_A[NumTensorDimensions-1] == 1; +#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) + if (not check) { + const auto offending_stride = + problem_shape.stride_A[NumTensorDimensions-1]; + std::ostringstream os; + os << "CollectiveConv::can_implement: " + "problem_shape.stride_A[NumTensorDimensions-1 = " + << (NumTensorDimensions-1) << "] = " + << offending_stride << " != 1"; + CUTLASS_TRACE_HOST( os.str() ); + } +#endif + implementable &= check; + } + + { + const bool check = problem_shape.stride_B[NumTensorDimensions-1] == 1; +#if (! defined(__CUDA_ARCH__)) && (CUTLASS_DEBUG_TRACE_LEVEL > 0) + if (not check) { + const auto offending_stride = + problem_shape.stride_B[NumTensorDimensions-1]; + std::ostringstream os; + os << "CollectiveConv::can_implement: " + "problem_shape.stride_B[NumTensorDimensions-1 = " + << (NumTensorDimensions-1) << "] = " + << offending_stride << " != 1\n"; + CUTLASS_TRACE_HOST( os.str() ); + } +#endif + implementable &= check; + } + + { + const auto & traversal_stride = problem_shape.traversal_stride; + for (auto stride: traversal_stride) { + implementable &= (stride >= 1 && stride <= 8); + } + } + + if constexpr (ConvOp == conv::Operator::kDgrad && not is_strided_dgrad) { + const auto & traversal_stride = problem_shape.traversal_stride; + for (auto stride: traversal_stride) { + implementable &= (stride == 1); + } + } + + constexpr int tma_alignment_bits = 128; + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. 
+ auto shape_B_orig = problem_shape.get_shape_B(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + { + const bool check = cutlass::detail::check_alignment(shape_A_orig, StrideA{}); + if (not check) { + CUTLASS_TRACE_HOST("A shape and/or strides have alignment issue."); + } + implementable &= check; + } + + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + { + const bool check = cutlass::detail::check_alignment(shape_B_orig, StrideB{}); + if (not check) { + CUTLASS_TRACE_HOST("B shape and/or strides have alignment issue."); + } + implementable &= check; + } + + if (not implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + if (is_im2col_A || is_im2col_B) { + // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit] + constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; + auto flt_data = (ConvOp == conv::Operator::kWgrad) ? 
problem_shape.shape_C : problem_shape.shape_B; + for (int i = 0; i < problem_shape.RankS; ++i) { + // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array + implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) + && ((flt_data[i+1] - 1) * problem_shape.dilation[i] <= offset_limit); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { + + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int64_t input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int64_t output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + { + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + } + + // When groups > 1, it should be a Grouped Conv. 
+ if (problem_shape.groups > 1) { + implementable &= TileShapeMNKLRank > 3; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Only Grouped Conv can support groups > 1.\n"); + return false; + } + } + + // Only support Grouped Wgrad currently. + if constexpr (TileShapeMNKLRank > 3) { + implementable &= ConvOp == conv::Operator::kWgrad; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv Only support Grouped Wgrad currently.\n"); + return false; + } + } + + // Grouped Wgrad channel check. + if constexpr (is_grouped_wgrad) { + + int input_K = size<0>(problem_shape.get_shape_A()); + int input_C = size<0>(problem_shape.get_shape_B()); + + implementable &= input_K == input_C; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Conv's input K and input C do not match.\n"); + return false; + } + + int output_K = size<0>(problem_shape.get_shape_C()); + int output_C = size<1,0>(problem_shape.get_shape_C()); + + implementable &= input_K == output_K; + implementable &= input_C == output_C * problem_shape.groups; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's input and output K,C and groups do not match\n"); + return false; + } + + constexpr int Tile_N = size<1>(TileShape{}); + constexpr int GroupsPerTile = size<3>(TileShapeMNKL_{}); + + implementable &= Tile_N / GroupsPerTile == input_C / problem_shape.groups; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Grouped Wgrad's Tile_N, GroupsPerTile and input_C, groups do not match.\n"); + return false; + } + } + + // The extents of linearized problem shape should be int32_t type(maximum is 2^31-1). 
+ if constexpr (is_im2col_A || is_im2col_B) { + auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast(s); }); }; + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits::max()) & + (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits::max()); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits::max()); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n"); + return false; + } + } + + return true; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mk, tBgB_nk, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = 
load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mk(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _); + auto tensor_b_coord = get<1>(cta_coord_mnkl); + if constexpr (is_grouped_wgrad) { + // in grouped wgrad, tensor A = NZPQK, tensor B = NDHWC, tensor C = KTRSc, where C = G*c, c = channel_per_group = 8,16,32. + // CTA Tiling follows output tensor KTRSc. So cta_size_m = K/CTA_TILE_M. cta_size_n = T*R*S*ceil(c/CTA_TILE_N) = T*R*S*1 = T*R*S. + // tensor_a_coord = K_idx = cta_coord_m. + // tensor_b_coord = TRS_idx * C/CTA_TILE_N + C_idx = cta_coord_n * get<1,0>(shape(tBgB_nk) + cta_coord_m, + // because K == C and CTA_TILE_M == CTA_TILE_N => C_idx = K_idx = cta_coord_m. + tensor_b_coord = get<0>(cta_coord_mnkl) + get<1>(cta_coord_mnkl) * get<1,0>(shape(tBgB_nk)); + } + Tensor tBgB = tBgB_nk(_, tensor_b_coord, _); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if constexpr (is_strided_dgrad) { + // construct gemm-k tile coord for gB + auto [conv_k, flt_coord, out_coord] = *k_tile_iter; + auto gemm_k_tile = prepend(flt_coord, conv_k); // (k,s,r,t) + + // gA doesn't have a gemm-k (k,s,r,t) iterator mode because it's not an im2col tensor + auto offset_kqpzn = append(prepend(out_coord, _0{}),_0{}); // (k,q,p,z,n) + auto tAgA_offset = make_tensor(tAgA.data() + offset_kqpzn, tAgA.layout()); // (TMA, k) + + if (cute::elect_one_sync()) 
{ + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA_offset(_,conv_k), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,gemm_k_tile) , tBsB(_,write_stage)); + } + } + else { + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mk - The tiled tma tensor for input A + /// gB_nk - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + auto K_A = conditional_return(get<0>(K), K); + Tensor mA_mk = observed_tma_load_a_->get_tma_tensor(make_shape(M, K_A)); + Tensor mB_nk = observed_tma_load_b_->get_tma_tensor(make_shape(N, K)); + + // Tile the tensors and defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mk = cta_mma.partition_A(gA_mk); // (MMA, MMA_M, MMA_K, m, k) + Tensor tCgB_nk = 
cta_mma.partition_B(gB_nk); // (MMA, MMA_N, MMA_K, n, k) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mk, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mk)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nk, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nk)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mk, gB_nk, // for scheduler + tAgA_mk, tBgB_nk, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + /* This helps avoid early exit of ctas in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a 
collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB + > + CUTLASS_DEVICE auto + mma(MainloopPipeline pipeline, + MainloopPipelineState mainloop_pipe_consumer_state, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + int k_tile_count) + { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available (phase bit flips from mainloop_pipe_consumer_state.phase() value) + pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainlop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + return mainloop_pipe_consumer_state; + } + + CUTLASS_DEVICE auto + mma_init(TensorStorage& 
shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + TiledMma tiled_mma; + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = tiled_mma.make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = tiled_mma.make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + +private: + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..11eefed94182c8d8870a65c9f4d937ede5db5421 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -0,0 +1,785 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/packed_stride.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + conv::Operator ConvOp, + int Stages, + int NumSpatialDims, + class ClusterShape, + class KernelSchedule, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class ElementB_, + class TiledMma_, + class TileTraitsA_, + class TileTraitsB_> +struct CollectiveConv< + MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>, + TileShape_, + ElementA_, + ElementB_, + TiledMma_, + TileTraitsA_, + TileTraitsB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>; + using TileShape = TileShape_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; + using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; + using SmemLayoutA = typename TileTraitsA_::SmemLayout; + using 
SmemLayoutB = typename TileTraitsB_::SmemLayout; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; + static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; + // Deduce the kernel-facing stride tuple types based on the dispatch policy + // (which is a function of the number of spatial dimensions, the algorithm, etc.) + using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A()); + using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B()); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + using ProblemShape = ConvProblemShape; + + static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + // The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad + // kernel is im2col for tensor A and tiled for tensor B. 
+ static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyA - invalid SM90 TMA copy atom specified."); + static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyB - invalid SM90 TMA copy atom specified."); + + static constexpr bool is_im2col_A = detail::is_im2col_load::value; + static constexpr bool is_im2col_B = detail::is_im2col_load::value; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + ElementB const* ptr_B{nullptr}; + }; + +private: + // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled 
for + // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor + // B since operand A, B is swapped. + // Get tma_load_a instantce. + template + static constexpr auto + get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) { + if constexpr (is_im2col_A) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + // The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values. + cute::array stride_srt{}; + for (int i = 0; i < NumSpatialDimensions; ++i) { + stride_srt[i] = ConvOp == conv::Operator::kDgrad ? + -problem_shape.dilation[NumSpatialDimensions-1-i] : + problem_shape.dilation[NumSpatialDimensions-1-i]; + } + + return make_im2col_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + product_each(shape(SmemLayoutA{}(_,_,_0{}))), + size<1>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + shape(stride_srt)); + } + // TMA tiled mode for tensor A in wgrad kernel. + else { + return make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); + } + } + + // Get tma_load_b instantce. + template + static constexpr auto + get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) { + // TMA im2col mode for tensor B in wgrad kernel. 
+ if constexpr (is_im2col_B) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + return make_im2col_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + product_each(shape(SmemLayoutB{}(_,_,_0{}))), + size<0>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + cute::reverse(shape(problem_shape.dilation))); + } + else { + return make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); + } + } + +public: + + // Performs im2col transformations on the input of type ConvProblemShape + static constexpr auto + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + + if constexpr (is_im2col_A || is_im2col_B) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + } + + // Device side kernel params + struct Params { + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); + + // Assumption: StrideA is congruent with Problem_MK + // Select TMA load type according to convolution operator. 
+ using TensorShapeA = cute::conditional_t; + + using TensorShapeB = cute::conditional_t; + + using TMA_A = decltype(get_tma_load_a_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeA{}, StrideA{})), + ConvProblemShape{})); + + using TMA_B = decltype(get_tma_load_b_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeB{}, StrideB{})), + ConvProblemShape{})); + + // Members + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + // Lowers the host side user facing arguments to the kernel facing lauch params + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple + // tma desc creation depends on the original untransformed domain. + + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. 
+ auto shape_B_orig = problem_shape.get_shape_B(); + + // Fill inferred cute strides from flat stride arrays + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); + Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); + + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape); + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + Arguments const& args) { + // Activation and Filter channel mode extents much match + bool implementable = true; + // channel mode is major + implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1; + + constexpr int tma_alignment_bits = 128; + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = problem_shape.get_shape_B(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_B_orig, StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + // Check valid padding values for TMA_LOAD_IM2COL + constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 
256 : 16); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0; + implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + if (is_im2col_A || is_im2col_B) { + // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1] + constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; + auto flt_data = (ConvOp == conv::Operator::kWgrad) ? 
problem_shape.shape_C : problem_shape.shape_B; + for (int i = 0; i < problem_shape.RankS; ++i) { + // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array + implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) + && ((flt_data[i+1] - 1) * problem_shape.dilation[i] < offset_limit); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: tensor coordinate offset values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::ostringstream os; +#endif + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int64_t input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (input_stride[i] != input_shape_size) { + os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n input_shape_size: " << input_shape_size + << "\n input_shape: " << input_shape + << "\n input_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = 
problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int64_t output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (output_stride[i] != output_shape_size) { + os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n output_shape_size: " << input_shape_size + << "\n output_shape: " << input_shape + << "\n output_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + } + + // Conv kernels only support cross correlation mode currently. 
+ implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + + if (problem_shape.groups > 1) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); + return false; + } + + if constexpr (is_im2col_A || is_im2col_B) { + auto [M, N, K, L] = cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + auto to_64b = [](auto S) { return transform_leaf(S, [](auto s) { return static_cast(s); }); }; + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + implementable &= (cute::product(to_64b(M)) <= cutlass::platform::numeric_limits::max()) & + (cute::product(to_64b(L)) <= cutlass::platform::numeric_limits::max()); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + implementable &= (cute::product(to_64b(K)) <= cutlass::platform::numeric_limits::max()); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: the extents exceed the maximum number.\n"); + return false; + } + } + + return true; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) + /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) + /// The rest of the tensors can be specified as needed by this collective. + /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with + /// StrideA and StrideB set up for TMA + template + CUTLASS_DEVICE auto + load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ + //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k) + Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k) + + // Make tiled views, defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + return cute::make_tuple(gA_mk, gB_nk); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_producer_state, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + int lane_predicate 
= cute::elect_one_sync(); + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + auto [gA_mk, gB_nk] = load_inputs; + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK 
smem_pipe_producer_state for _writing_ + pipeline.producer_acquire(smem_pipe_producer_state); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); + + int write_stage = smem_pipe_producer_state.index(); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_producer_state + ++smem_pipe_producer_state; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_producer_state); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_consumer_state, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + 
TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_consumer_state; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_consumer_state; + } + + 
warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_consumer_state and smem_pipe_release + ++smem_pipe_consumer_state; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h new file mode 100644 index 0000000000000000000000000000000000000000..fbef858a54eda2ffbfea30e8ff9bd570bcf841f9 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv2d_problem_size.h @@ -0,0 +1,658 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing convolution problem sizes. + + Conv2dProblem desciption: + activation (NHWC), + filter (KRSC), + output (NPQK), + pading (pad_h, pad_w), + stride (stride_h, stride_w), + dilation (dilation_h, dilation_w). + + Free functions to map: + Map tensor extents (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) + Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) + Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/functional.h" + +namespace cutlass { +namespace conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Problem size structure +struct Conv2dProblemSize { + + // Conv2d strictly problem size parameters + int N, H, W, C, P, Q, K, R, S; + int pad_h, pad_w; + int stride_h, stride_w; + int dilation_h, dilation_w; + Mode mode; + + // Conv2d implementation-related parameters + int split_k_slices; + int 
groups; + + // + // Methods + // + +public: + CUTLASS_HOST_DEVICE + Conv2dProblemSize(): + N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), + pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), + mode(Mode::kConvolution), split_k_slices(1), groups(1) { } + + /// Constructor for default padding, stride, dilation, and split-K + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + int N, + int H, + int W, + int C, + int P, + int Q, + int K, + int R, + int S, + Mode mode + ): + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), + mode(mode), split_k_slices(1), groups (1) { } + + /// Constructor + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + int N, + int H, + int W, + int C, + int K, + int R, + int S, + int P, + int Q, + int pad_h, + int pad_w, + int stride_h, + int stride_w, + int dilation_h, + int dilation_w, + Mode mode, + int split_k_slices = 1, + int groups = 1 + ): + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), + pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), + dilation_h(dilation_h), dilation_w(dilation_w), + mode(mode), split_k_slices(split_k_slices), groups (groups) { } + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // set user-defined output size and sets P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ + cutlass::MatrixCoord stride, // stride_h, stride_w + cutlass::MatrixCoord dilation, // dilation_h, dilation_w + cutlass::Tensor4DCoord output_size, // NPQK + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), + 
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(padding[0]), pad_w(padding[2]), + stride_h(stride.row()), stride_w(stride.column()), + dilation_h(dilation.row()), dilation_w(dilation.column()), + mode(mode), split_k_slices(split_k_slices), groups(groups) {} + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // computes output size and sets P and Q (skip output from ctor arguments) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord padding, // pad_h, upper_pad_h, pad_w, upper_pad_w + cutlass::MatrixCoord stride, // stride_h, stride_w + cutlass::MatrixCoord dilation, // dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(padding[0]), pad_w(padding[2]), + stride_h(stride.row()), stride_w(stride.column()), + dilation_h(dilation.row()), dilation_w(dilation.column()), + mode(mode), split_k_slices(split_k_slices), groups(groups) { + // set output P and Q + P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1; + Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1; + } + + /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord + // set user-defined output size and sets P and Q (skip padding, striding, and dilation) + CUTLASS_HOST_DEVICE + Conv2dProblemSize( + cutlass::Tensor4DCoord input_size, // NHWC + cutlass::Tensor4DCoord filter_size, // KRSC + cutlass::Tensor4DCoord output_size, // NPQK + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), + 
K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), + dilation_h(1), dilation_w(1), + mode(mode), split_k_slices(split_k_slices), groups(groups) {} + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv2dProblemSize reset_mode(cutlass::conv::Mode mode_) { + Conv2dProblemSize tmp(*this); + tmp.mode = mode_; + return tmp; + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv2dProblemSize reset_split_k_slices(int split_k_slices_) { + Conv2dProblemSize tmp(*this); + tmp.split_k_slices = split_k_slices_; + return tmp; + } + + /// Equality operator (ignores mode and split_k_slice) + CUTLASS_HOST_DEVICE + bool operator==(Conv2dProblemSize const &conv) const { + return ( + (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && + (K == conv.K) && (R == conv.R) && (S == conv.S) && + (P == conv.P) && (Q == conv.Q) && + (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && + (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && + (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) + ); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(Conv2dProblemSize const &rhs) const { + return !(*this == rhs); + } + + /// Returns activation extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord activation_extent() const { + + return cutlass::Tensor4DCoord ({N, H, W, C}); + } + + /// Returns filter extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const { + + return is_deconv ? 
cutlass::Tensor4DCoord ({C, R, S, K / groups}) + : cutlass::Tensor4DCoord ({K, R, S, C / groups}); + } + + /// Returns output extent as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord output_extent() const { + + return cutlass::Tensor4DCoord ({N, P, Q, K}); + } + + /// Returns activation size in number of elements + CUTLASS_HOST_DEVICE + int64_t activation_size() const { + + return static_cast(N) * static_cast(H) * + static_cast(W) * static_cast(C); + } + + /// Returns filter size in number of elements + CUTLASS_HOST_DEVICE + int64_t filter_size() const { + + return static_cast(K) * static_cast(R) * + static_cast(S) * static_cast(C) / + static_cast(groups); + } + + /// Returns output size in number of elements + CUTLASS_HOST_DEVICE + int64_t output_size() const { + + return static_cast(N) * static_cast(P) * + static_cast(Q) * static_cast(K); + } + + /// Returns padding as Tensor4DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor4DCoord padding() const { + + return cutlass::Tensor4DCoord ({pad_h, pad_h, pad_w, pad_w}); + } + + /// Returns stride as MatrixCoord + CUTLASS_HOST_DEVICE + cutlass::MatrixCoord stride() const { + + return cutlass::MatrixCoord ({stride_h, stride_w}); + } + + /// Returns dilation as MatrixCoord + CUTLASS_HOST_DEVICE + cutlass::MatrixCoord dilation() const { + + return cutlass::MatrixCoord ({dilation_h, dilation_w}); + } + + ///////////////////////////////////////////////////////////////// + // Methods used for strided dgrad implementation + ///////////////////////////////////////////////////////////////// + /// Number of filter r positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_r(int r) const { + return ((R - r + stride_h - 1) / stride_h); + } + + /// Number of filter s positions to accumulate in gemm-k dim + CUTLASS_HOST_DEVICE + int num_gemm_k_filter_s(int s) const { + return ((S - s + stride_w - 1) / stride_w); + } + + /// Number of filter positions to accumulate in gemm-k dim + 
CUTLASS_HOST_DEVICE + int num_gemm_k_filter_positions(int r, int s) const { + return num_gemm_k_filter_r(r) * num_gemm_k_filter_s(s); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ImplicitGemm helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Determine the problem size of the implicit GEMM operation +CUTLASS_HOST_DEVICE +cutlass::gemm::GemmCoord implicit_gemm_problem_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + // Compute problem size + switch (conv_operator) { + case Operator::kFprop: + return gemm::GemmCoord( + problem_size.N * problem_size.P * problem_size.Q, + problem_size.K, + problem_size.R * problem_size.S * problem_size.C / problem_size.groups + ); + case Operator::kDeconv: + case Operator::kDgrad: + return gemm::GemmCoord( + problem_size.N * problem_size.H * problem_size.W, + problem_size.C, + problem_size.R * problem_size.S * problem_size.K + ); + case Operator::kWgrad: + return gemm::GemmCoord( + problem_size.K, + problem_size.R * problem_size.S * problem_size.C, + problem_size.N * problem_size.P * problem_size.Q + ); + default: + break; + } + return gemm::GemmCoord(); +} + +// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int iterations = 0; + + if (group_mode == GroupMode::kNone) { + + if (algorithm == IteratorAlgorithm::kFixedChannels) { + + int positions_per_iteration = threadblock_K / problem_size.C; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / 
positions_per_iteration; + break; + + default: + break; + } + } + else if (algorithm == IteratorAlgorithm::kFewChannels) { + + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; + break; + + default: + break; + } + } + else { + int elements_per_split_k_slice = 0; + + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDeconv: + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } + + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } else { // Group conv + + int channels_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + // In group conv, if k_per_group < 
threadblock_N, one Threadblock will calculate multiple groups + if (problem_size.groups != 1) { + if (k_per_group < threadblock_N) { + iterations *= threadblock_N / k_per_group; + } + } + break; + + default: + break; + } + } else if (algorithm == IteratorAlgorithm::kOptimized) { + // Current optimized iterator only support GroupMode::kSingleGroup + if (group_mode == GroupMode::kSingleGroup) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } + + } + + return iterations; +} + + +template +CUTLASS_HOST_DEVICE +int depthwise_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int n = problem_size.N; + int p = (problem_size.P + Output_P - 1) / Output_P; + int q = (problem_size.Q + Output_Q - 1) / Output_Q; + + int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + return iterations; +} + + +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations_per_channel( + Operator conv_operator, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + + int iterations = 0; //0 means not applicable + if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S; + break; + + case Operator::kDeconv: + case Operator::kDgrad: + iterations = problem_size.R * problem_size.S; + break; + + default: + break; + } + } + return iterations; +} + +//////////////////////////////////////////////////////////////////////////////// +// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) 
+//////////////////////////////////////////////////////////////////////////////// +/// Returns ImplicitGemm tensor A extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor B extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor C extent as Tensor4DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); + default : break; + } + return cutlass::Tensor4DCoord(); +} + +/// Returns ImplicitGemm tensor A size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_a_size( + Operator conv_operator, + Conv2dProblemSize 
const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor B size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_b_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor C size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_c_size( + Operator conv_operator, + Conv2dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); + default : break; + } + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Strided dgrad helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Returns number of CTAs tile M to cover valid MMAs per starting filter postion +CUTLASS_HOST_DEVICE +int strided_dgrad_tile_m_per_filter( + Conv2dProblemSize const &problem_size, + int tile_size_m) { + + // Compute NHW rows in Dx output 
that needs MMA per starting filter position + int rows_h_per_filter = (problem_size.H + problem_size.stride_h - 1) / problem_size.stride_h; + int rows_w_per_filter = (problem_size.W + problem_size.stride_w - 1) / problem_size.stride_w; + int rows_nhw_per_filter = problem_size.N * rows_h_per_filter * rows_w_per_filter; + + // Number of CTAs tile M to cover valid MMAs per starting filter postion + int tile_m_per_filter = (rows_nhw_per_filter + tile_size_m - 1) / tile_size_m; + + return tile_m_per_filter; +} + +// Computes starting Dx coord (h, w) for given starting filter postion +CUTLASS_HOST_DEVICE +void strided_dgrad_starting_coords( + Conv2dProblemSize const &problem_size, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int r, int s, + int &start_h, int &start_w) { + + // function locals for remainder by fast divmod + int pad_h_rem_, pad_w_rem_; + + // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; + stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); + int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); + stride_h_divmod.divmod(start_h, r_); + + //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; + stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); + int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); + stride_w_divmod.divmod(start_w, s_); +} + +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h new file mode 100644 index 0000000000000000000000000000000000000000..48bf056e17014400a6bc41b87193a05de3cb9c96 --- /dev/null +++ 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/conv3d_problem_size.h @@ -0,0 +1,519 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief This file contains definitions and utility functions for describing convolution problem sizes. + + Conv3dProblem desciption: + activation (NDHWC), + filter (KTRSC), + output (NZPQK), + pading (pad_d, pad_h, pad_w), + stride (stride_d, stride_h, stride_w), + dilation (dilation_d, dilation_h, dilation_w). + + Free functions to map: + Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator) + Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) + Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) +*/ + +#pragma once + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +namespace cutlass { +namespace conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Problem size structure +struct Conv3dProblemSize : public Conv2dProblemSize { + // + // Type definitions + // + + // 3D coordinate for padding, stride, and dilation in (d, h, w) dimensions + using Coord3D = Coord<3>; + + // + // Data members + // + + // Conv3d strictly problem size parameters + int D, T, Z; // input depth, filter depth, output depth + int pad_d; // padding in depth dimension + int stride_d; // stride in depth dimension + int dilation_d; // dilation in depth dimension + + // + // Methods + // +public: + CUTLASS_HOST_DEVICE + Conv3dProblemSize(): + Conv2dProblemSize(), + D(0), T(0), Z(0), + pad_d(0), + stride_d(1), + dilation_d(1) { } + + /// Constructor for default padding, stride, dilation, and split-K + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + int N, + int D, + int H, + int W, + int C, + int Z, + int P, + int Q, + int K, + int T, + int R, + int S, + Mode mode + ): + Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode), + D(D), T(T), Z(Z), + pad_d(T / 2), stride_d(1), dilation_d(1) { } + + /// Constructor + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + 
int N, + int D, + int H, + int W, + int C, + int K, + int T, + int R, + int S, + int Z, + int P, + int Q, + int pad_d, + int pad_h, + int pad_w, + int stride_d, + int stride_h, + int stride_w, + int dilation_d, + int dilation_h, + int dilation_w, + Mode mode, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + N, H, W, C, K, R, S, P, Q, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w, + mode, split_k_slices, groups), + D(D), T(T), Z(Z), + pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { } + + /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D + // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + Coord3D padding, // pad_d, pad_h, pad_w + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::Tensor5DCoord output_size, // NZPQK + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {padding[1], padding[1], padding[2], padding[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), Z(output_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { } + + /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D + // *computes* output size and sets Z, P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + Coord3D padding, // pad_d, 
pad_h, pad_w + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {padding[1], padding[1], padding[2], padding[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { + // set output Z + Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; + } + + /// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D + // *computes* output size and sets Z, P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + CUTLASS_STL_NAMESPACE::tuple padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1], + CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { + // set output Z + Z = ((D + pad_d + 
CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1; + } + + /// Equality operator (ignores mode and split_k_slice) + CUTLASS_HOST_DEVICE + bool operator==(Conv3dProblemSize const &conv) const { + return ( + (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) && + (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && + (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && + (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && + (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && + (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) + ); + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(Conv3dProblemSize const &rhs) const { + return !(*this == rhs); + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) { + Conv3dProblemSize tmp(*this); + tmp.mode = mode_; + return tmp; + } + + // Reset covolution mode in the problem + CUTLASS_HOST_DEVICE + Conv3dProblemSize reset_split_k_slices(int split_k_slices_) { + Conv3dProblemSize tmp(*this); + tmp.split_k_slices = split_k_slices_; + return tmp; + } + + /// Returns activation extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord activation_extent() const { + + return cutlass::Tensor5DCoord ({N, D, H, W, C}); + } + + /// Returns filter extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const { + + return is_deconv ? 
cutlass::Tensor5DCoord ({C, T, R, S, K}) + : cutlass::Tensor5DCoord ({K, T, R, S, C}); + } + + /// Returns output extent as Tensor5DCoord + CUTLASS_HOST_DEVICE + cutlass::Tensor5DCoord output_extent() const { + + return cutlass::Tensor5DCoord ({N, Z, P, Q, K}); + } + + /// Returns activation size in number of elements + CUTLASS_HOST_DEVICE + int64_t activation_size() const { + + return static_cast(N) * static_cast(D) * + static_cast(H) * static_cast(W) * + static_cast(C); + } + + /// Returns filter size in number of elements + CUTLASS_HOST_DEVICE + int64_t filter_size() const { + + return static_cast(K) * static_cast(T) * + static_cast(R) * static_cast(S) * + static_cast(C); + } + + /// Returns output size in number of elements + CUTLASS_HOST_DEVICE + int64_t output_size() const { + + return static_cast(N) * static_cast(Z) * + static_cast(P) * static_cast(Q) * + static_cast(K); + } + + /// Returns padding as Coord3D + CUTLASS_HOST_DEVICE + Coord3D padding() const { + + return Coord3D ({pad_d, pad_h, pad_w}); + } + + /// Returns stride as MatrixCoord + CUTLASS_HOST_DEVICE + Coord3D stride() const { + + return Coord3D ({stride_d, stride_h, stride_w}); + } + + /// Returns dilation as MatrixCoord + CUTLASS_HOST_DEVICE + Coord3D dilation() const { + + return Coord3D ({dilation_d, dilation_h, dilation_w}); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ImplicitGemm helper functions // +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Determine the problem size of the implicit GEMM operation +CUTLASS_HOST_DEVICE +cutlass::gemm::GemmCoord implicit_gemm_problem_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + // Compute problem size + switch (conv_operator) { + case Operator::kFprop: + return gemm::GemmCoord( + problem_size.N * problem_size.Z * problem_size.P * problem_size.Q, + problem_size.K, + problem_size.T * 
problem_size.R * problem_size.S * problem_size.C + ); + case Operator::kDeconv: + case Operator::kDgrad: + return gemm::GemmCoord( + problem_size.N * problem_size.D * problem_size.H * problem_size.W, + problem_size.C, + problem_size.T * problem_size.R * problem_size.S * problem_size.K + ); + case Operator::kWgrad: + return gemm::GemmCoord( + problem_size.K, + problem_size.T * problem_size.R * problem_size.S * problem_size.C, + problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + ); + default: + break; + } + return gemm::GemmCoord(); +} + +// Determine the number of gemm_k iterations for conv2d problem using implicit gemm algorithm +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv3dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int iterations = 0; + int elements_per_split_k_slice = 0; + if (group_mode == GroupMode::kNone) { + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDeconv: + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } else if (group_mode == 
GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.T * problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } + + return iterations; +} + +//////////////////////////////////////////////////////////////////////////////// +// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) +//////////////////////////////////////////////////////////////////////////////// +/// Returns ImplicitGemm tensor A extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor B extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); + case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor C extent as Tensor5DCoord +CUTLASS_HOST_DEVICE +cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + 
switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); + default : break; + } + return cutlass::Tensor5DCoord(); +} + +/// Returns ImplicitGemm tensor A size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_a_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor B size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_b_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); + default : break; + } + return 0; +} + +/// Returns ImplicitGemm tensor C size in number of elements +CUTLASS_HOST_DEVICE +int64_t implicit_gemm_tensor_c_size( + Operator conv_operator, + Conv3dProblemSize const &problem_size) { + switch (conv_operator) { + case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: + case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); + case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); + default : break; + } + return 0; +} + +} // namespace conv +} // namespace cutlass + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3c31c21b2508914d10d41bb865a6da145bf3c106 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convnd_problem_shape.hpp @@ -0,0 +1,601 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing convolution problem shapes. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/conv/convolution.h" + +#include "cute/container/array.hpp" + +#if ! defined(__CUDACC_RTC__) +#include +#endif + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion. +// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK) +// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. 
+template < + conv::Operator ConvOp_, + int NumSpatialDimensions_ +> +struct ConvProblemShape { + // + // Alias types for members + // + + static constexpr int RankS = NumSpatialDimensions_; + static constexpr int RankT = NumSpatialDimensions_ + 2; + static constexpr conv::Operator ConvOp = ConvOp_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + using SpatialExtent = cute::array; + using TensorExtent = cute::array; + using TensorStride = cute::array; + using ShapePadding = SpatialExtent; + using TraversalStride = SpatialExtent; + using ShapeDilation = SpatialExtent; + using Corner = SpatialExtent; + + // + // Members + // + cutlass::conv::Mode mode{}; + TensorExtent shape_A{}; + TensorStride stride_A{}; + TensorExtent shape_B{}; + TensorStride stride_B{}; + TensorExtent shape_C{}; + TensorStride stride_C{}; + + // asymmetric padding, both upper and lower padding must be >= 0 + ShapePadding lower_padding{}; + ShapePadding upper_padding{}; + TraversalStride traversal_stride{}; + ShapeDilation dilation{}; + int groups = 1; + + // + // Methods + // + + ConvProblemShape() = default; + + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, // convolution/cross-correlation + TensorExtent shape_act, // [n,d,h,w,c] + TensorStride stride_act, // [n,d,h,w,c] + TensorExtent shape_flt, // [k,t,r,s,c] + TensorStride stride_flt, // [k,t,r,s,c] + ShapePadding lower_padding, // [pad_d, pad_h, pad_w] + ShapePadding upper_padding, // [pad_d, pad_h, pad_w] + TraversalStride tstride, // [stride_d, stride_h, stride_w] + ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] + int groups) + : mode(mode) + , lower_padding(lower_padding) + , upper_padding(upper_padding) + , traversal_stride(tstride) + , dilation(dilation) + , groups(groups) { + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, 
stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Allow user input of xformed activation stride to support non-packed strides. + ConvProblemShape( + conv::Mode mode, // convolution/cross-correlation + TensorExtent shape_act, // [n,d,h,w,c] + TensorStride stride_act, // [n,d,h,w,c] + TensorExtent shape_flt, // [k,t,r,s,c] + TensorStride stride_flt, // [k,t,r,s,c] + TensorStride stride_xformed_act, // [n,z,p,q,k] + ShapePadding lower_padding, // [pad_d, pad_h, pad_w] + ShapePadding upper_padding, // [pad_d, pad_h, pad_w] + TraversalStride tstride, // [stride_d, stride_h, stride_w] + ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] + int groups) + : mode(mode) + , lower_padding(lower_padding) + , upper_padding(upper_padding) + , traversal_stride(tstride) + , dilation(dilation) + , groups(groups) { + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order. 
+ ConvProblemShape( + conv::Mode mode, + TensorExtent shape_act, + TensorExtent shape_flt, + ShapePadding lower_padding, + ShapePadding upper_padding, + TraversalStride tstride, + ShapeDilation dilation, + int groups) + : ConvProblemShape( + mode, + shape_act, + packed_stride_right_major(shape_act), + shape_flt, + packed_stride_right_major(shape_flt), + lower_padding, + upper_padding, + tstride, + dilation, + groups) { + } + +#if ! defined(__CUDACC_RTC__) + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(stride_act_.size() == stride_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(stride_flt_.size() == stride_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), 
traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Allow user input of xformed activation stride to support non-packed strides. + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list stride_xformed_act_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + TensorStride stride_xformed_act{}; + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto 
[shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list shape_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + stride_act = packed_stride_right_major(shape_act); + stride_flt = packed_stride_right_major(shape_flt); + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + 
set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } +#endif // not defined(__CUDACC_RTC__) + + // Set shape and stride of tensor A/B/C according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ------ | ------ | ------ | ------| + // | ShapeA | NDHWC | NZPQK | NZPQK | + // | ShapeB | KTRSC | KTRSC | NDHWC | + // | ShapeC | NZPQK | NDHWC | KTRSC | + // + // Input comes from calculate_xformed_act, which does NOT depend on ConvOp. + CUTLASS_HOST_DEVICE + constexpr void + set_shape_stride_ABC( + TensorExtent shape_act, + TensorStride stride_act, + TensorExtent shape_flt, + TensorStride stride_flt, + TensorExtent shape_xformed_act, + TensorStride stride_xformed_act) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("*** set_shape_stride_ABC ***"); + printf("\n shape_act: "); + print(shape_act); + printf("\n stride_act: "); + print(stride_act); + printf("\n shape_flt: "); + print(shape_flt); + printf("\n stride_flt: "); + print(stride_flt); + printf("\n shape_xformed_act: "); + print(shape_xformed_act); + printf("\n stride_xformed_act: "); + print(stride_xformed_act); + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + printf("\n ConvOp: Fprop"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + printf("\n ConvOp: Dgrad"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + printf("\n ConvOp: Wgrad"); + } + printf("\n"); +#endif + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + shape_A = shape_act; + stride_A = stride_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_xformed_act; + stride_C = stride_xformed_act; + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_act; + stride_C = stride_act; + } + else if constexpr (ConvOp == 
cutlass::conv::Operator::kWgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_act; + stride_B = stride_act; + shape_C = shape_flt; + stride_C = stride_flt; + } +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n shape_A: "); + print(shape_A); + printf("\n stride_A: "); + print(stride_A); + printf("\n shape_B: "); + print(shape_B); + printf("\n stride_B: "); + print(stride_B); + printf("\n shape_C: "); + print(shape_C); + printf("\n stride_C: "); + print(stride_C); +#endif + } + + // Get A extents. + // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_A() const { + using cute::make_shape; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + cute::reverse(take<0, RankT - 1>(shape_A)), + shape_A[RankT - 1]); + } + // For wgrad kernel, we need to linearize NZPQ for tensor A + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_A[RankT - 1], + cute::product(take<0, RankT - 1>(shape_A))); + } + } + + // Get B extents. + // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T)) + // wgrad: B extents array contains [N,D,H,W,C]. 
Turn that into ((C), (W,H,D,N)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_B() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop) { + return make_shape( + shape_B[0], + reverse(take<1, RankT>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_B[RankT - 1], + reverse(take<0, RankT - 1>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)] + return make_shape( + shape_B[RankT - 1], + cute::insert<0>( + reverse(take<1, RankT - 1>(shape_B)), + shape_B[0])); + } + } + + // Get C extents. + // fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_C() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + reverse(take<0, RankT - 1>(shape_C)), + shape_C[RankT - 1]); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_C[0], + reverse(take<1, RankT>(shape_C))); + } + } + + // Static method that returns the canonical strides of tensors (layouts are right major and compact) + CUTLASS_HOST_DEVICE + static constexpr TensorStride + packed_stride_right_major(TensorExtent const& extents) { + TensorStride strides{}; + strides[RankT-1] = 1; + cute::for_each(cute::make_rseq{}, [&](auto i) { + strides[i] = extents[i+1] * strides[i+1]; + }); + return strides; + } + + // Static method that returns the packed logical size of any TensorExtent + CUTLASS_HOST_DEVICE + static constexpr size_t + size(TensorExtent const& extents) { + size_t size = 1; + cute::for_each(cute::make_seq{}, 
[&](auto i) { + size *= extents[i]; + }); + return size; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_A() const { + return shape_A[0] * stride_A[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_B() const { + return shape_B[0] * stride_B[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_C() const { + return shape_C[0] * stride_C[0]; + } + + // Equality operator + CUTLASS_HOST_DEVICE + bool operator==(ConvProblemShape const& rhs) const { + using cute::for_each; + using cute::make_seq; + + bool is_equal = true; + + // Compare all tensor extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (shape_A[i] == rhs.shape_A[i]) + && (shape_B[i] == rhs.shape_B[i]); + }); + + // Compare all spatial extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (lower_padding[i] == rhs.lower_padding[i]) + && (upper_padding[i] == rhs.upper_padding[i]) + && (traversal_stride[i] == rhs.traversal_stride[i]) + && (dilation[i] == rhs.dilation[i]); + }); + + return is_equal; + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(ConvProblemShape const &rhs) const { + return !(*this == rhs); + } + +private: + CUTLASS_HOST_DEVICE + constexpr auto + calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) { + TensorExtent shape_xformed_act{}; + // calculate n,z,p,q,k. 
+ // a helper lambda to compute a single spatial extent of the nzpqk tensor + auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { + return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; + }; + + shape_xformed_act[0] = shape_act[0]; // Activation N extent + cute::for_each(cute::make_seq{}, [&](auto i) { + shape_xformed_act[i+1] = nzpqk_extent( + shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]); + }); + shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent + + TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act); + + return cute::make_tuple(shape_xformed_act, stride_xformed_act); + } +}; + +template< + conv::Operator ConvOp, + int SpatialDim +> +void print(ConvProblemShape const& problem) { + printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n", + SpatialDim, int(ConvOp)); + printf("\tTensorA: "); + cute::print(problem.shape_A); printf(":"); + cute::print(problem.stride_A); printf("\n"); + printf("\tTensorB: "); + cute::print(problem.shape_B); printf(":"); + cute::print(problem.stride_B); printf("\n"); + printf("\tTensorC: "); + cute::print(problem.shape_C); printf(":"); + cute::print(problem.stride_C); printf("\n"); + printf("\tLower padding: "); print(problem.lower_padding); printf("\n"); + printf("\tUpper padding: "); print(problem.upper_padding); printf("\n"); + printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n"); + printf("\tDilation: "); print(problem.dilation); printf("\n"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..a3cc98b4740115aefd557468d01ad28fa9a1028a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/convolution.h @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief

This file contains definitions and utility functions for describing convolution problem sizes in terms of
activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and
dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm
tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types.

                        * Mapping convolutions to Gemm computation *

Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm
(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output.
The underlying gemm operation follows the standard gemm definition:

                                     C = A * B + C

                               A and B are input matrices
                              C is source and output matrix


For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped
to convolution tensors Activation, Filter and Output as described in the table below.

        ___________________________________________________________________________
         ConvolutionalOperator |        A        |      B         |       C
        ___________________________________________________________________________
        |                      |                 |                |                |
        |       Fprop          |    Activation   |    Filter      |     Output     |
        |       Dgrad          |    Output       |    Filter      |     Activation |
        |       Wgrad          |    Output       |    Activation  |     Filter     |
        ___________________________________________________________________________

In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output).

For example, it's confusing and error prone to document a convolution class or function
as operating on "A, B, Output." Instead, use the mapping functions below,
and adhere to using either A, B, C or Activation, Filter, Output.

Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap
Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm_enumerated_types.h"
#include "cutlass/matrix_coord.h"

namespace cutlass {
namespace conv {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Convolutional operator (which of the three tensors is being computed)
enum class Operator {
  kFprop,   ///< forward propagation: computes Output from Activation and Filter
  kDgrad,   ///< data gradient: computes Activation gradient from Output and Filter
  kWgrad,   ///< weight gradient: computes Filter gradient from Output and Activation
  kDeconv   ///< transposed convolution (deconvolution)
};

/// Distinguishes convolution from cross correlation
/// (convolution flips the filter; cross-correlation does not)
enum class Mode {
  kCrossCorrelation,
  kConvolution
};

/// Selects among several implementation variants trading off performance with simplicity
enum class IteratorAlgorithm {
  kAnalytic,            ///< functionally correct in all cases but lower performance
  kOptimized,           ///< optimized for R <= 32, S <= 32 and unity-stride dgrad
  kFixedChannels,       ///< Analytic algorithm optimized for fixed channel count (C == AccessSize)
  kFewChannels,         ///< Analytic algorithm optimized for few channels (C divisible by AccessSize)
  kFixedStrideDilation  ///< Optimized for fixed stride and dilation
};

/// Distinguishes among partial specializations that accelerate certain problems where convolution
/// stride is unit.
enum class StrideSupport {
  kStrided,   ///< arbitrary convolution stride
  kUnity,     ///< unit convolution stride
  kFixed      ///< fixed convolution stride
};

/// Identifies split-K mode
enum class SplitKMode {
  kNone,
  kSerial,
  kParallel
};

/// Identifies group mode
enum class GroupMode {
  kNone,
  kSingleGroup,   ///< One CTA calculates one group or less
  kMultipleGroup, ///< One CTA calculates multiple groups
  kDepthwise      ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups)
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Shape of a tensor (compile-time NHWC extents and derived products)
template <
  int N = 1,
  int H = 1,
  int W = 1,
  int C = 1
>
struct TensorNHWCShape {
  static int const kN = N;
  static int const kH = H;
  static int const kW = W;
  static int const kC = C;

  static int const kHW = H * W;
  static int const kNHW = N * kHW;
  static int const kNHWC = N * H * W * C;

  // total number of elements in the tensor
  static int const kCount = kNHWC;

  //
  // Static member functions
  //

  /// Returns a Coord object
  CUTLASS_HOST_DEVICE
  static Coord<4> toCoord() {
    return make_Coord(kN, kH, kW, kC);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Shape of a conv2d stride, which controls how the filter convolves around the input volume
template <
  /// Stride in horizontal direction
  int u = 1,
  /// Stride in vertical direction
  int v = 1
>
struct Stride2D {
  static int const kU = u;
  static int const kV = v;

  //
  // Static member functions
  //

  /// Returns a Coord object
  CUTLASS_HOST_DEVICE
  static Coord<2> toCoord() {
    return make_Coord(kU, kV);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace conv
} //
namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0802921d60ce1809a7da67805de0f045c3511b19 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/detail.hpp @@ -0,0 +1,137 @@ + +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + // Helper function to get the problem shape +template +auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) { + return T::get_problem_shape_MNKL(problem_shape); +} + +template +ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) { + return problem_shape; +} + +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | +// | Shape_L | _1 | (V,U,O) | _1 | + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) { + return problem_shape; +} + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + using cute::insert; + using cute::make_shape; + using cute::reverse; + using cute::take; + + constexpr int RankT 
= SpatialDim + 2; + + if constexpr (ConvOp == conv::Operator::kWgrad) { + auto M_xformed = problem_shape.shape_C[0]; + auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C)); + auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kFprop){ + auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C)); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O) + auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] + auto K_xformed = insert<0>( + (reverse(take<1,RankT - 1>(problem_shape.shape_B))), + problem_shape.shape_B[0]); + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } +} + +// Assuming im2col linearization +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | +// | Shape_L | _1 | (V*U*O) | _1 | +template +CUTLASS_HOST_DEVICE +constexpr auto +get_linearized_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + + auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape); + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return cute::make_shape(cute::product(M), N, K, cute::product(L)); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + 
return cute::make_shape(M, N, cute::product(K), L); + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d60469f429f94f4b8152a02d9db232eea5698e56 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -0,0 +1,448 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +// common +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/arch/mma.h" +#include "cutlass/trace.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/device_kernel.h" + +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + ConvUniversalAdapter is a stateful, reusable handle built around a kernel + of type cutlass::conv::kernel::ConvUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, static methods + are exposed that bypass the stateful methods or args->params lowering. 
+*/ +template +class ConvUniversalAdapter +{ +public: + using ConvKernel = GetUnderlyingKernel_t; + using TileShape = typename ConvKernel::TileShape; + using ElementA = typename ConvKernel::ElementA; + using ElementB = typename ConvKernel::ElementB; + using ElementC = typename ConvKernel::ElementC; + using ElementD = typename ConvKernel::ElementD; + using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC; + using DispatchPolicy = typename ConvKernel::DispatchPolicy; + using CollectiveMainloop = typename ConvKernel::CollectiveMainloop; + using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + // Tease out meta-information about the conv algorithm + static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; + + // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! 
+ using OperatorClass = cute::conditional_t< + (cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1), + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + + using ArchTag = typename ConvKernel::ArchTag; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = ConvKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename 
CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + /// Argument structure: User API + using Arguments = typename ConvKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename ConvKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the conv can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + if (ConvKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += ConvKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace); + return ConvKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return ConvKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = ConvKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + ConvKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << 
cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes conv state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = ConvKernel::to_underlying_arguments(args, workspace); + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // account for dynamic smem capacity if needed + int smem_size = ConvKernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. 
+ Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = ConvKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling ConvKernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + CUTLASS_TRACE_HOST("ConvUniversal::run()"); + dim3 const block = ConvKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = ConvKernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v and + cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; + dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || + ConvKernel::ArchTag::kMinComputeCapability == 101) { + if constexpr (!cute::is_static_v) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + void* kernel_params[] = {¶ms}; + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + 
CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + launch_result = cuda_adapter->launch(grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel_params, + kernel_index); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90 + || ConvKernel::ArchTag::kMinComputeCapability == 100 + ) { + if constexpr (is_static_1x1x1) { + device_kernel<<>>(params); + launch_result = Status::kSuccess; + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + } + else { + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || + ConvKernel::ArchTag::kMinComputeCapability == 101) { + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, + cluster, + fallback_cluster, + block, + smem_size, + stream, + kernel, + kernel_params); + } + } + } + } + else { + launch_result = Status::kSuccess; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + device_kernel<<>>(params); + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+ Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0 + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, kernel_index); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(args, workspace, stream, cuda_adapter); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..387574b989681ba6f9e5e6fa333dda109b7f7aa6 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/direct_convolution.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level Depthwise Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConvolution { +public: + + using UnderlyingKernel = DirectConvolutionKernel_; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const 
kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + + using ReorderKernel = typename UnderlyingKernel::ReorderKernel; + + private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + DirectConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + if (kGroupMode != conv::GroupMode::kDepthwise) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K != args.problem_size.groups && + args.problem_size.C != args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == 
conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.ptr_reordered_B = args.ref_reordered_B.data(); + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + // Launch reorder kernel + if (params_.ptr_reordered_B != nullptr) { + dim3 grid = ReorderKernel::get_grid_shape(params_); + dim3 block = ReorderKernel::get_block_shape(); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + } + + // Launch main kernel + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + // Dynamic SMEM size based on input params. + int smem_size = int(params_.get_smem_size()); + + // Make sure we can use that much shared memory. + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status != cudaSuccess) + return Status::kErrorInternal; + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + + int get_smem_size() { return int(params_.get_smem_size()); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..a9aae87bc1c57a20e27298b4f227726dd199a769 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -0,0 +1,388 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Template for device-level Implicit GEMM Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ImplicitGemmConvolution { +public: + + using UnderlyingKernel = GetUnderlyingKernel_t; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename 
UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + +private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + ImplicitGemmConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + // Check that tensor sizes don't exceed maximum supported size + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.activation_size() * sizeof(ElementA) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + else if (kConvolutionalOperator == conv::Operator::kDgrad || + kConvolutionalOperator == conv::Operator::kDeconv) { + if (args.problem_size.activation_size() * sizeof(ElementC) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.activation_size() * sizeof(ElementB) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + + // check group conv constraint + if (args.problem_size.groups != 1) { + if (kGroupMode == conv::GroupMode::kNone) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K % args.problem_size.groups || + args.problem_size.C % args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + // split-k is not supported + if (args.problem_size.split_k_slices != 1) { + return Status::kErrorInvalidProblem; + } + + int k_per_group = 
args.problem_size.K / args.problem_size.groups; + // k_per_group should be multiple of ThreadblockShape N, one CTA calculates one group + if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { + return Status::kErrorInvalidProblem; + } + // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculates multiple groups + if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { + return Status::kErrorInvalidProblem; + } + + // current optimized iterator algo only supports SingleGroup mode + if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && + kGroupMode != conv::GroupMode::kSingleGroup) { + return Status::kErrorInvalidProblem; + } + } + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // check for unsupported problem sizes for strided dgrad / deconv implementation + if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && + kStrideSupport == conv::StrideSupport::kStrided) { + // split-k (serial or parallel) is not supported for strided dgrad / deconv + if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) { + return Status::kErrorNotSupported; + } + + // dilation > {1x1} is not supported for strided dgrad / deconv + if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { + return 
Status::kErrorNotSupported; + } + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. + // The user needs to call a reduction operator to obtain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. 
+ Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + if (args.problem_size.split_k_slices > 1) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); + + if (status != cudaSuccess) { + return Status::kErrorInternal; + } + } + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + cutlass::Status launch_result = cutlass::Status::kSuccess ; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + void* kernel_params[] = {¶ms_}; + launch_result = cuda_adapter->launch( + grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index + ); + } + else { + launch_result = Status::kErrorInternal; + } + } + else { + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + return run(stream, cuda_adapter, kernel_index); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter, kernel_index); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..efd3dcbad093cf8d11036a63a9b6638d1801aeee --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Template for device-level fused activation's scale+bias+relu and Implicit GEMM Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ImplicitGemmConvolutionFusion { +public: + + using ImplicitGemmFusionKernel = ImplicitGemmFusionKernel_; + + using ElementA = typename ImplicitGemmFusionKernel::ElementA; + using LayoutA = typename ImplicitGemmFusionKernel::LayoutA; + using ElementB = typename ImplicitGemmFusionKernel::ElementB; + using LayoutB = typename ImplicitGemmFusionKernel::LayoutB; + +// 
using ElementScaleBias = typename ImplicitGemmFusionKernel::ElementScaleBias; +// using LayoutScaleBias = typename ImplicitGemmFusionKernel::LayoutScaleBias; + + using ElementC = typename ImplicitGemmFusionKernel::ElementC; + using LayoutC = typename ImplicitGemmFusionKernel::LayoutC; + using ElementAccumulator = typename ImplicitGemmFusionKernel::ElementAccumulator; + using ElementCompute = typename ImplicitGemmFusionKernel::ElementCompute; + using OperatorClass = typename ImplicitGemmFusionKernel::OperatorClass; + using ArchTag = typename ImplicitGemmFusionKernel::ArchTag; + using ThreadblockShape = typename ImplicitGemmFusionKernel::ThreadblockShape; + using WarpShape = typename ImplicitGemmFusionKernel::WarpShape; + using InstructionShape = typename ImplicitGemmFusionKernel::InstructionShape; + using ThreadblockSwizzle = typename ImplicitGemmFusionKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename ImplicitGemmFusionKernel::EpilogueOutputOp; + static int const kStages = ImplicitGemmFusionKernel::kStages; + static int const kConvDim = ImplicitGemmFusionKernel::kConvDim; + using WarpMmaOperator = typename ImplicitGemmFusionKernel::WarpMmaOperator; + using ArchMmaOperator = typename ImplicitGemmFusionKernel::ArchMmaOperator; + using MathOperator = typename ImplicitGemmFusionKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmFusionKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmFusionKernel::kIteratorAlgorithm; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename ImplicitGemmFusionKernel::Arguments; + +private: + + /// Kernel parameters object + typename ImplicitGemmFusionKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + ImplicitGemmConvolutionFusion() { } + + 
/// Determines whether the Implicit GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = ImplicitGemmFusionKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = ImplicitGemmFusionKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size), + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. 
+ // The user needs to call a reduction operator to obtain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + if (args.problem_size.split_k_slices > 1) { + + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream); + + if (status != cudaSuccess) { + return Status::kErrorInternal; + } + } + + // initialize the params structure from the arguments + params_ = typename ImplicitGemmFusionKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes Implicit GEMM state from arguments. 
+ Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_scale = args.ref_A_scale.data(); + params_.ptr_bias = args.ref_A_bias.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d569cb1c3e6d6c7da188691a94384d43259d2be0 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/dispatch_policy.hpp @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convolution.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/arch/arch.h" + +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" + +////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv { + +////////////////////////////////////////////////////////////////////////////// + +// +// Policies for categorical dispatch of mainloop against kernel grid schedules +// +struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { }; +struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; +struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; + +// +// Collective Mainloop Policies +// + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA +// for fprop +template< + conv::Operator ConvOp_, + int Stages_, + int NumSpatialDimensions_, + class 
ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>>, + class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90, + int PipelineAsyncMmaStages_ = 1 +> +struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { + static constexpr int Stages = Stages_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + static constexpr Operator ConvOp = ConvOp_; + static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + + static_assert(NumSpatialDimensions >= 1); + static_assert(! (cute::is_same_v || + cute::is_same_v), + "Persistent schedules not support for conv yet."); +}; + + + +// SM100 tensor op kernel schedule +struct KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = 0; + static constexpr int AccumulatorPipelineStageCount = 0; +}; + +// Pseudo-policies for builder auto override that dispatches to the KernelImplicitTmaWarpSpecializedSm100 +// but for opting into 1 or 2 SM atoms +struct KernelImplicitTmaWarpSpecialized1SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; +struct KernelImplicitTmaWarpSpecialized2SmSm100 : KernelImplicitTmaWarpSpecializedSm100 { }; + +struct KernelStridedDgradTmaWs1SmSm100 { }; +struct KernelStridedDgradTmaWs2SmSm100 { }; + +// Policy for implicit gemm kernel +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelScheduleImplicitTmaWarpSpecializedSm100 : KernelImplicitTmaWarpSpecializedSm100 { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +// n-buffer in smem (Blackwell TMA), pipelined with Blackwell UMMA and TMA, fprop +template< + conv::Operator ConvOp_, + int Stages_, + int NumSpatialDimensions_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = 
cute::Shape,cute::C<1>,cute::C<1>> +> +struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { + static constexpr int Stages = Stages_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + static constexpr Operator ConvOp = ConvOp_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100; + + static_assert(NumSpatialDimensions >= 1); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv + +////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..af804df30e76a156af33f7095da64614370e466c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/conv_universal.hpp @@ -0,0 +1,65 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/detail/dependent_false.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device CONV kernel type that treats CONV as + * a composition of a collective mainloop and a collective epilogue. 
+**/ +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ = void, + class Enable = void +> +class ConvUniversal { + static_assert(cutlass::detail::dependent_false, + "Could not find a valid specialization at the kernel layer to dispatch against."); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + +//////////////////////////////////////////////////////////////////////////////// +#include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp" +#include "cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h new file mode 100644 index 0000000000000000000000000000000000000000..f9647a598799cf233962457f8d2cad7e59e46cf5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d.h @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions for threadblock-scoped epilogue. 
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h"
#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h"
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/threadblock/conv2d_tile_iterator.h"
#include "cutlass/conv/threadblock/implicit_gemm_pipelined.h"
#include "cutlass/conv/threadblock/implicit_gemm_multistage.h"
#include "cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h"
#include "cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution_fusion.h"
#include "cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace conv {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Selects a threadblock-scoped epilogue for Tensor Op convolutions.
/// The primary template uses the generic Tensor Op epilogue; the Sm70
/// specialization below substitutes the Volta-specific variant.
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogue {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount
  >::Epilogue;
};

/// Sm70 (Volta) partial specialization: route to the Volta Tensor Op epilogue.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogue<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  OutputOp
> {

  using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount
  >::Epilogue;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// SIMT epilogue with an additional broadcast tensor input (e.g. residual add).
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaSimt,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess,
  typename PermuteDLayout = layout::NoPermute,
  conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity,
  int Rank = 4
>
struct DefaultConvEpilogueWithBroadcastSimt {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt<
    Shape,
    WarpMmaSimt,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess,
    false,
    PermuteDLayout,
    StrideSupport,
    Rank
  >::Epilogue;
};

/// SIMT broadcast epilogue variant used by the strided-dgrad kernel path.
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaSimt,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad<
    Shape,
    WarpMmaSimt,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};

/// Tensor Op epilogue with broadcast; primary template (non-Volta architectures).
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastTensorOp {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};

/// Sm70 (Volta) partial specialization of the broadcast Tensor Op epilogue.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename ElementTensor,
  typename ElementVector,
  typename OutputOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithBroadcastTensorOp<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  ElementOutput,
  ElementTensor,
  ElementVector,
  OutputOp,
  ElementsPerAccess
  > {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    ElementTensor,
    ElementVector,
    OutputOp,
    ElementsPerAccess
  >::Epilogue;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tensor Op epilogue with an elementwise reduction; primary template.
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithReductionTensorOp {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    OutputOp,
    ReductionOp,
    ElementsPerAccess
  >::Epilogue;
};

/// Sm70 (Volta) partial specialization of the reduction Tensor Op epilogue.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename ElementOutput,
  typename OutputOp,
  typename ReductionOp,
  int ElementsPerAccess
>
struct DefaultConvEpilogueWithReductionTensorOp<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  ElementOutput,
  OutputOp,
  ReductionOp,
  ElementsPerAccess
  > {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithReductionVoltaTensorOp<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    ElementOutput,
    OutputOp,
    ReductionOp,
    ElementsPerAccess
  >::Epilogue;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Defaults for strided Dgrad
template <
  typename ArchTag,
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogueStridedDgrad {
  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount
  >::Epilogue;
};

/// Sm70 (Volta) partial specialization of the strided-dgrad epilogue.
template <
  typename Shape,
  typename WarpMmaTensorOp,
  int PartitionsK,
  typename OutputOp
>
struct DefaultConvEpilogueStridedDgrad<
  arch::Sm70,
  Shape,
  WarpMmaTensorOp,
  PartitionsK,
  OutputOp
> {

  using Epilogue = typename epilogue::threadblock::DefaultEpilogueVoltaTensorOpStridedDgrad<
    Shape,
    WarpMmaTensorOp,
    PartitionsK,
    OutputOp,
    OutputOp::kCount
  >::Epilogue;
};

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace conv
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h
new file mode 100644
index 0000000000000000000000000000000000000000..27a96a5602494e2abe3980b3d07d54c49dcb9932
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_dgrad.h
@@ -0,0 +1,1927 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dDgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dDgrad; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided and +// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename 
MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided +// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + > + >; + + using SmemIteratorB = typename 
MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity Strided +// and multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + 
using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm Dgrad Unity +// 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level 
GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for optimized IteratorAlgorithm Dgrad Unity Strided +// and multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, 
layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided and +// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define 
iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOpStridedDgrad< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided +// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kStrided, + AccessTypeB + > + >; + + using SmemIteratorB = typename 
MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogueStridedDgrad< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm Dgrad Unity +// 2 stage pipeline +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define 
iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, 
and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using 
Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + 
cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + 
LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + 
ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + 
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, 
arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, 
+ int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + 
ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dDgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define 
iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + StrideSupport::kUnity + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = 
typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..932d1abdc6e2c80a4a1e8d1eb805cbedbe5ac78a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -0,0 +1,2007 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop +template < + 
typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template <
+  typename ElementA, typename LayoutA,
+  typename ElementB, typename LayoutB,
+  typename ElementC, typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape, typename WarpShape, typename InstructionShape,
+  typename EpilogueOutputOp, typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::StrideSupport StrideSupport,
+  int AlignmentA, int AlignmentB
+>
+struct DefaultConv2dFprop <
+  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
+  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
+  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
+  IteratorAlgorithm::kAnalytic, StrideSupport, AlignmentA, AlignmentB
+> {
+
+  // Threadblock-scoped matrix-multiply core (Tensor Core path: B staged column-major).
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+    Stages, MathOperatorTag>;
+
+  // A operand: activation tiles.  AlignedArray arguments restored from the
+  // Alignment template parameters (stripped during extraction) -- confirm upstream.
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
+  using IteratorA =
+    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<
+      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+      ElementA, LayoutA,
+      ThreadMapA,
+      AccessTypeA
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // B operand: filter tiles.
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
+  using IteratorB =
+    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<
+      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+      ElementB, LayoutB,
+      ThreadMapB,
+      AccessTypeB
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Cache-global only for full 128-bit accesses on B; otherwise cache at all levels.
+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
+    ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
+      ? cutlass::arch::CacheOperation::Global
+      : cutlass::arch::CacheOperation::Always;
+
+  // Multistage threadblock-scoped implicit-GEMM mainloop.
+  using Mma = threadblock::ImplicitGemmMultistage<
+    ThreadblockShape,
+    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
+    IteratorB, SmemIteratorB, CacheOpB,
+    MmaPolicy, Stages
+  >;
+
+  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
+
+  // Tensor-op epilogue.
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
+    ThreadblockShape, WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Kernel-level operator.
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kFprop
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dFprop specialization for FixedChannels IteratorAlgorithm and
+/// multistage pipeline.
+template <
+  typename ElementA, typename LayoutA,
+  typename ElementB, typename LayoutB,
+  typename ElementC, typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape, typename WarpShape, typename InstructionShape,
+  typename EpilogueOutputOp, typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::StrideSupport StrideSupport,
+  int AlignmentA, int AlignmentB
+>
+struct DefaultConv2dFprop <
+  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
+  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
+  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
+  IteratorAlgorithm::kFixedChannels, StrideSupport, AlignmentA, AlignmentB
+> {
+
+  // Threadblock-scoped matrix-multiply core.
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+    Stages, MathOperatorTag>;
+
+  // A operand: activation tiles, fixed-channel-count iterator.
+  // AlignedArray arguments restored from the Alignment template parameters
+  // (stripped during extraction) -- confirm upstream.
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
+  using IteratorA =
+    cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
+      cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+      ElementA, LayoutA,
+      ThreadMapA,
+      AccessTypeA
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // B operand: filter tiles, fixed-channel-count iterator.
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
+  using IteratorB =
+    cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
+      cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+      ElementB, LayoutB,
+      ThreadMapB,
+      AccessTypeB
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Cache-global only for full 128-bit accesses on B; otherwise cache at all levels.
+  static cutlass::arch::CacheOperation::Kind const CacheOpB =
+    ((sizeof_bits<ElementB>::value * AlignmentB) == 128)
+      ? cutlass::arch::CacheOperation::Global
+      : cutlass::arch::CacheOperation::Always;
+
+  // Multistage threadblock-scoped implicit-GEMM mainloop.
+  using Mma = threadblock::ImplicitGemmMultistage<
+    ThreadblockShape,
+    IteratorA, SmemIteratorA, arch::CacheOperation::Always,
+    IteratorB, SmemIteratorB, CacheOpB,
+    MmaPolicy, Stages
+  >;
+
+  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
+
+  // Tensor-op epilogue.
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
+    ThreadblockShape, WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Kernel-level operator.
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kFprop
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dFprop specialization for FixedChannels IteratorAlgorithm and
+/// two stage pipeline.
+template <
+  typename ElementA, typename LayoutA,
+  typename ElementB, typename LayoutB,
+  typename ElementC, typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape, typename WarpShape, typename InstructionShape,
+  typename EpilogueOutputOp, typename ThreadblockSwizzle,
+  typename MathOperatorTag,
+  conv::StrideSupport StrideSupport,
+  int AlignmentA, int AlignmentB
+>
+struct DefaultConv2dFprop <
+  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
+  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
+  EpilogueOutputOp, ThreadblockSwizzle, 2, MathOperatorTag,
+  IteratorAlgorithm::kFixedChannels, StrideSupport, AlignmentA, AlignmentB
+> {
+
+  // Threadblock-scoped matrix-multiply core (two-stage).
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp,
+    2, MathOperatorTag>;
+
+  // A operand: fixed-channel activation iterator wrapped for the two-stage mainloop.
+  // AlignedArray arguments restored (stripped during extraction) -- confirm upstream.
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using AccessTypeA = cutlass::AlignedArray<ElementA, AlignmentA>;
+  using IteratorA =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFixedChannels<
+        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+        ElementA, LayoutA,
+        ThreadMapA,
+        AccessTypeA
+      >
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // B operand: fixed-channel filter iterator wrapped for the two-stage mainloop.
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using AccessTypeB = cutlass::AlignedArray<ElementB, AlignmentB>;
+  using IteratorB =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFixedChannels<
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+        ElementB, LayoutB,
+        ThreadMapB,
+        AccessTypeB
+      >
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaTensorOp = typename MmaCore::MmaTensorOp;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Two-stage threadblock-scoped implicit-GEMM mainloop.
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA, SmemIteratorA,
+    IteratorB, SmemIteratorB,
+    ElementC, LayoutC,
+    MmaPolicy
+  >;
+
+  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
+
+  // Tensor-op epilogue.
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp<
+    ThreadblockShape, WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Kernel-level operator.
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma, Epilogue, ThreadblockSwizzle, conv::Operator::kFprop
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Conv2dFprop specialization for FewChannels IteratorAlgorithm and
+/// multistage pipeline.
+template <
+  typename ElementA, typename LayoutA,
+  typename ElementB, typename LayoutB,
+  typename ElementC, typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape, typename WarpShape, typename InstructionShape,
+  typename EpilogueOutputOp, typename ThreadblockSwizzle,
+  int Stages,
+  typename MathOperatorTag,
+  conv::StrideSupport StrideSupport,
+  int AlignmentA, int AlignmentB
+>
+struct DefaultConv2dFprop <
+  ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator,
+  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
+  EpilogueOutputOp, ThreadblockSwizzle, Stages, MathOperatorTag,
+  IteratorAlgorithm::kFewChannels, StrideSupport, AlignmentA, AlignmentB
+> {
+
+  // Define the core components from GEMM (continues on the next hunk lines).
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+ ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kFewChannels, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define 
iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorFewChannels< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + 
ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM 
components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, 
layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename 
MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +// multistage pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + layout::TensorNCxHWx, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + layout::TensorCxRSKx, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + 
using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, 
MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB, + int InterleavedK +> +struct DefaultConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::SmemThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::SmemThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + 
using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename 
LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + 
arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + 
cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport 
StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, 
+ layout::NoPermute, + StrideSupport, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..85b142a0e27d3c39d2d742c1709582ea3156b801 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv2dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity +> struct DefaultConv2dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + 
LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, 
ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h new 
file mode 100644 index 0000000000000000000000000000000000000000..ccc751535c7a8c2c2f49b8d34f9d0e9a8edbd90e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines a default configuration for convolution with absolute maximum calculation. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_absmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_absmax.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access 
granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithAbsMax { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithAbsMax< + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementAuxOutput, + ElementC, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithAbsMax< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..b7fca981b0e0b44dca2b9add89808ac2b036d021 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -0,0 +1,221 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename 
ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFpropWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, 
+ IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..5c2c7ffc700b089e449d4f18008c26cdb8d6c81a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -0,0 +1,130 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Reduction based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_reduction.h" +#include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename EpilogueReductionOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithReduction { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename 
ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + EpilogueOutputOp, + EpilogueReductionOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..99e353d80a0b3b37818371737c8189eee6b5ed38 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dGroupFprop; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline that supports all GroupMode. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA 
= + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and +/// 2 stage pipeline that supports all GroupMode. 
+ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using 
AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage +/// pipeline that supports GroupMode::kSingleGroup. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB 
= + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and +/// 2 stage pipeline that supports GroupMode::kSingleGroup. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = 
cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h new file mode 100644 index 0000000000000000000000000000000000000000..d55d453eb02675d0b626865b6625dc4bf2b12e92 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -0,0 +1,1011 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dWgrad; + +///////////////////////////////////////////////////////////////////////////////////////////////// + 
+///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = 
cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM 
components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, 
ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and two +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM 
components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename 
ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + 
arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + 
cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AccessTypeA, + int AccessTypeB +> +struct DefaultConv2dWgrad < + ElementA, + LayoutA, + 
ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + AccessTypeA, + AccessTypeB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + 
ThreadblockSwizzle, + conv::Operator::kWgrad + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..83b680ec3591de39470013d71b808f356306b2f0 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv2dWgradFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Analytic IteratorAlgorithm and multistage +// 
pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dWgradFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + 
LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmWgradFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + IteratorScaleBias, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv2dWgradFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + 
LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmWgradFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + IteratorScaleBias, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h new file mode 100644 index 0000000000000000000000000000000000000000..309924cebafe82df1651b0fb5542eb14dc6c5388 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_dgrad.h @@ -0,0 +1,736 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dDgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dDgrad; + +/// Defines a kernel for Conv3dDgrad specialization for Analytic IteratorAlgorithm Dgrad Strided +// and multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + 
arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm Dgrad Strided +// and multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + 
cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // 
Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // 
cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct 
DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define 
the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..4b6709f08a4b2e93a0e3b93e1a343896368451c2 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -0,0 +1,981 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity +> struct DefaultConv3dFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic Iterator Algorithm +/// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + 
ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + 
ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm +/// and 2 stage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm and multistage +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< 
+ cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int 
Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // 
Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + 
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the 
core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// 2 stage 
pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using 
Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..513de059c6591a47fbf2c75f81d1400c96fe9d48 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity +> struct DefaultConv3dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + 
using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and +/// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, 
ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..2fb12c2a502f9af2aa5383288e6695a108abdf60 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access 
granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv3dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + conv::StrideSupport 
StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv3dFpropWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h new file mode 100644 index 
0000000000000000000000000000000000000000..6b50d2087e20889a934eaf34c7f120badff8a435 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_conv3d_wgrad.h @@ -0,0 +1,936 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dWgrad; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + 
arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm and two +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename 
MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and multistage +// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + 
arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm and two +// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, OperatorClass, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = 
typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 
Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + 
cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + 
ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h new file mode 100644 index 0000000000000000000000000000000000000000..a58046ffa414e6556d14b20c5402fb5d82cfbf64 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d.h @@ -0,0 +1,999 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename 
ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDeconv2d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename 
MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> 
+struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + 
EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename 
MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the 
core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// 2 stage 
pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + 
using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + 
using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + 
typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename 
epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from 
the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..e62187e3680e55a71d77bf4fee19276357753f98 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv2dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename 
ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + 
AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue 
+ using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h new file mode 100644 index 0000000000000000000000000000000000000000..cb7ca07e6eb9b18f3006d51e742772f755852e23 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultDeconv3d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> 
+struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + 
+ // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + 
cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization 
for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = 
typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..e25c8b2eee551252b902e0c0845416b753194df1 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -0,0 +1,309 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv3dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + 
ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity + 
>::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided + >::Kernel; + + 
// Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..ba70813e4c94104522a05897a60811d26ae3c6a4 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" +#include "cutlass/conv/kernel/direct_convolution.h" + +#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" + +// Direct Conv Related Header files +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" + +#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = cutlass::sizeof_bits::value / 
cutlass::sizeof_bits::value +> struct DefaultDepthwiseFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop with direct convolution algorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + // MatrixShape + typename StrideShape = cutlass::MatrixShape<-1, -1>, + // MatrixShape< Height, Width> + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDepthwiseDirect2dConvFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> 
+struct DefaultDepthwiseFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, // cutlass::arch::OpMultiplyAdd + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< + ThreadblockShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + sizeof_bits::value, + 2, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + cutlass::conv::GroupMode::kDepthwise + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = 
typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + // One warp handles the entrie groups per cta. 
+ static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + 
((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, 
+ ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + + + + // One warp handles the entrie groups per cta. + static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); + static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); + + // Activations loaded by threadblock + static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ActivationShape = + cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag, + 
IteratorAlgorithm::kFixedStrideDilation, + StrideShape, + DilationShape, + ActivationShape>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + StrideShape, + DilationShape, + ActivationShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue, + IteratorAlgorithm::kFixedStrideDilation + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..8c04988790b9b03e41e9c2245dbdf2e5e8af493b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/direct_convolution.h @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multi-staged Depthwise Convolution kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template > ///! OutputShape per ThreadBlock +struct DirectConvolutionParams { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + using Arguments = Arguments_; + using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; + + using ThreadblockShape = typename Mma::Shape; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static conv::GroupMode const kGroupMode = GroupMode_; + static int const kStages = Mma::kStages; + + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + int smem_size_; + + int gemm_k_iterations; + int gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename 
Mma::IteratorB::Element const *ptr_B; + typename Mma::IteratorB::Element *ptr_reordered_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + int split_k_slices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} + + /// + CUTLASS_HOST_DEVICE + DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) + : problem_size(args.problem_size), + implicit_gemm_problem_size( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), + ptr_B(args.ref_B.data()), + ptr_reordered_B(args.ref_reordered_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + split_k_slices(args.problem_size.split_k_slices) { + gemm_k_iterations = + depthwise_gemm_k_iterations(kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + 
args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + + // Dynamic SMEM usage because stride and dilation are runtime params. + smem_size_ = (cutlass::platform::max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); + } + + CUTLASS_HOST_DEVICE + int get_smem_size() { + // Dynamic Smem Size + return smem_size_; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ReorderKernel { + using Params = Params_; + using ElementB = ElementB_; + + union SharedStorage {}; + + static unsigned int const kReorderKernelThreadPerCTA = 128; + + CUTLASS_HOST_DEVICE + ReorderKernel() {} + + CUTLASS_HOST_DEVICE + static dim3 get_grid_shape(Params const ¶ms) { + return dim3{static_cast( + (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / + kReorderKernelThreadPerCTA), + 1, + 1}; + } + + CUTLASS_HOST_DEVICE + static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } + + CUTLASS_HOST_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + int64_t m = static_cast(params.problem_size.groups); + int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); + const ElementB *src_with_type = static_cast(params.ptr_B); + ElementB *dst_with_type = static_cast(params.ptr_reordered_B); + + int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; + int64_t index_m = linear_index / n; + int64_t index_n = linear_index % n; + int64_t new_linear_index = index_m + index_n * m; + + if (linear_index < m * n) { + dst_with_type[new_linear_index] = src_with_type[linear_index]; + } + return; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! 
Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! Group mode + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +struct DirectConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const 
kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefB ref_reordered_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + TensorRefB const & ref_reordered_B = nullptr, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + ref_reordered_B(ref_reordered_B), + split_k_mode(split_k_mode) + { + + } + + }; + + using Params = + typename 
cutlass::conv::kernel::DirectConvolutionParams; + + using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if threadblock is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_column_offset = 0; + int filter_row_offset = 0; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode == GroupMode::kDepthwise) { + iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + iterator_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_reordered_B, + thread_idx, + MatrixCoord( + filter_row_offset, + iterator_column_offset + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + + // Compute threadblock-scoped matrix multiply-add + // Epilogue is fused in the mainloop + mma(params.gemm_k_iterations, + accumulators, + iterator_A, + params.iterator_A, + iterator_B, + params.iterator_B, + accumulators, + epilogue, + output_op, + iterator_D, + iterator_C, + params.split_k_slices); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..d3fa0e907bb94c2716861395324b2da0346cebde --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -0,0 +1,455 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode +> +struct ImplicitGemmConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + 
static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + int 
gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations( + kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, 
ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode != GroupMode::kDepthwise) { + int k_per_group = params.problem_size.K / params.problem_size.groups; + int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; + int channels_per_group = params.problem_size.C / params.problem_size.groups; + iterator_A_column_offset += group_idx * channels_per_group; + } else { + iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + iterator_A_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + 
+ // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..5451c176f4027bc40a3ec3466efe69dea18f5342 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined fused activation's scale+bias+relu and Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! 
Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionFusion { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + + using ElementScaleBias = typename Mma::IteratorScaleBias::Element; + using LayoutScaleBias = typename Mma::IteratorScaleBias::Layout; + + using ElementC = typename EpilogueOutputOp::ElementOutput; + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefScaleBias = typename Mma::IteratorScaleBias::TensorRef; + 
using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefScaleBias ref_scale; + TensorRefScaleBias ref_bias; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefScaleBias const & ref_scale, + TensorRefScaleBias const & ref_bias, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & 
output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_bias(ref_bias), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + gemm::GemmCoord implicit_gemm_problem_size{}; + int swizzle_log_tile{0}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Mma::IteratorScaleBias::Params iterator_scale_bias{}; + typename Mma::IteratorScaleBias::Element const *ptr_scale = nullptr; + typename Mma::IteratorScaleBias::Element const *ptr_bias = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C {}; + typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D {}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; + + // + // Methods + // + Params() = default; + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_scale_bias(args.problem_size, args.ref_scale.layout()), + ptr_scale(args.ref_scale.data()), + ptr_bias(args.ref_bias.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + 
ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionFusion() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // 
Construct iterators to A scale/bias vector + typename Mma::IteratorScaleBias iterator_scale_bias( + params.iterator_scale_bias, + params.problem_size, + params.ptr_scale, + params.ptr_bias, + thread_idx, + MatrixCoord( + 0, (kConvolutionalOperator == conv::Operator::kFprop) ? + (threadblock_tile_idx.k() * Mma::Shape::kK) : + // Wgrad + (threadblock_tile_idx.n() * Mma::Shape::kN) + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, + iterator_B, iterator_scale_bias, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Run efficient epilogue + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h new file mode 100644 index 0000000000000000000000000000000000000000..071854cd629e26417ca987bc24681665c8d30702 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -0,0 +1,492 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! 
Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionStridedDgrad { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are 
the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // Strided dgrad uses a specialized threadblock swizzle for functionality and performance + static_assert((platform::is_same::value) || + (platform::is_same>::value) || + (platform::is_same>::value) || + (platform::is_same>::value), + "Needs ThreadblockSwizzle type specialized for strided dgrad"); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size{}; + TensorRefA ref_A{}; + TensorRefB ref_B{}; + TensorRefC ref_C{}; + TensorRefC ref_D{}; + typename EpilogueOutputOp::Params output_op{}; + SplitKMode split_k_mode{}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC 
const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + FastDivmod stride_h_divmod{}; + FastDivmod stride_w_divmod{}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C{}; + typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D{}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; + + // + // Methods + // + Params() = default; + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + stride_h_divmod(args.problem_size.stride_h), + stride_w_divmod(args.problem_size.stride_w), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size, ThreadblockShape::kM), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size, ThreadblockShape::kM), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations = 
implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionStridedDgrad() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Compute starting filter position for strided dgrad + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(params.problem_size, + ThreadblockShape::kM); + int filter_tile_m = (threadblock_tile_idx.m() / tile_m_per_filter); + + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // int start_r = filter_tile_m / (params.problem_size.stride_w); + // int start_s = filter_tile_m % (params.problem_size.stride_w); + + int start_r, start_s; + params.stride_w_divmod(start_r, start_s, filter_tile_m); + + int filter_r = start_r; + int filter_s = start_s; + + if (params.problem_size.mode == Mode::kConvolution) { + filter_r = (params.problem_size.R - 1 - filter_r); + filter_s = 
(params.problem_size.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + params.problem_size, + params.stride_h_divmod, params.stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { + return; + } + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA + if (start_r < params.problem_size.R && start_s < params.problem_size.S) { + // Scale gemm_k_iterations for strided dgrad + int gemm_k_iterations = (params.gemm_k_iterations / (params.problem_size.R * params.problem_size.S) + ) * params.problem_size.num_gemm_k_filter_positions(start_r, start_s); + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + start_r, start_s, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. 
+ int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + threadblock_offset + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + if (output_op.is_source_needed()) + { + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + threadblock_offset); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
+ if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + } + + // Run epilogue with addend source iterator + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + // Run epilogue without addend source iterator + epilogue(output_op, iterator_D, accumulators); + } + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..0113473f9b28d7c657c07ff8f85e34fc66ea1ed1 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Convolution kernel with an epilogue that computes the absolute maximum value of the output + and a pre-activation-function auxiliary output. The auxiliary output is also (optionally) + stored to global memory. 
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/aligned_buffer.h"
+#include "cutlass/array.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/semaphore.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/conv2d_problem_size.h"
+#include "cutlass/conv/conv3d_problem_size.h"
+#include "cutlass/epilogue/threadblock/output_iterator_parameter.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// NOTE(review): throughout this vendored copy, angle-bracket template argument lists appear to
+// have been stripped by the extraction tool (e.g. `cutlass::TensorRef;`, `platform::is_same::value`,
+// `static_cast(...)`). Verify each such site against upstream CUTLASS before building.
+
+/// Implicit-GEMM convolution kernel whose epilogue additionally computes the absolute maximum
+/// of the output and writes a pre-activation auxiliary tensor. Grid-level entry point: each
+/// threadblock computes one output tile; split-K reduction is supported (serial via semaphore,
+/// or parallel via per-slice output offsets).
+template <
+  typename Mma_,                                  ///! Threadblock-scoped matrix multiply-accumulate
+  typename Epilogue_,                             ///! Epilogue
+  typename ThreadblockSwizzle_,                   ///! Threadblock swizzling function
+  conv::Operator ConvOperator,                    ///! Convolutional operator (Fprop, Dgrad, Wgrad)
+  typename ConvProblemSize_ = Conv2dProblemSize   ///! Convolutional operator on 2D or 3D problem
+>
+struct ImplicitGemmConvolutionWithAbsMax {
+
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static Operator const kConvolutionalOperator = ConvOperator;
+
+  using ElementA = typename Mma::IteratorA::Element;
+  using LayoutA = typename Mma::IteratorA::Layout;
+  using ElementB = typename Mma::IteratorB::Element;
+  using LayoutB = typename Mma::IteratorB::Layout;
+  using ElementC = typename EpilogueOutputOp::ElementOutput;
+
+  /// Set output tensor C layout
+  using LayoutC = LayoutA;
+
+  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
+  using ElementCompute = typename EpilogueOutputOp::ElementCompute;
+
+  using WarpMmaOperator = typename Mma::Policy::Operator;
+
+  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
+  using MathOperator = typename ArchMmaOperator::Operator;
+
+  using OperatorClass = typename WarpMmaOperator::OperatorClass;
+  using ArchTag = typename WarpMmaOperator::ArchTag;
+
+  using ThreadblockShape = typename Mma::Shape;
+  using WarpShape = typename WarpMmaOperator::Shape;
+  using InstructionShape = typename ArchMmaOperator::Shape;
+
+  static int const kStages = Mma::kStages;
+  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
+  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  // One CTA = 32 threads per warp * total warps in the threadblock tile.
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  using TensorRefA = typename Mma::IteratorA::TensorRef;
+  using TensorRefB = typename Mma::IteratorB::TensorRef;
+  // NOTE(review): template arguments stripped here — upstream these are
+  // TensorRef<ElementC, LayoutC> and TensorRef<aux element, LayoutC>; confirm against upstream.
+  using TensorRefC = cutlass::TensorRef;
+  using TensorRefAux = cutlass::TensorRef;
+
+  /// Check iterator A and B convolution dimension are the same and
+  // set device::ImplicitGemmConvolution::kConvDim
+  // NOTE(review): message below contains a duplicated word ("different different") — typo to fix upstream.
+  static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
+    "Convolution on different different dimensions is not supported");
+  static int const kConvDim = Mma::IteratorA::kConvDim;
+
+  /// Conv dimension and problem size structure (Conv2d or Conv3d)
+  using ConvProblemSize = ConvProblemSize_;
+
+  // Grouped convolution is not supported by this kernel.
+  static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
+
+  /// Wgrad C stride idx for implicit gemm algorithm
+  // Conv2d row-major matrix C (KxRSC)
+  // Conv3d row-major matrix C (KxTRSC)
+  // NOTE(review): is_same's template arguments were stripped — presumably
+  // is_same<LayoutC, layout::TensorNHWC> selecting stride index 2 for Conv2d, 3 for Conv3d; verify.
+  static int const kWgradCStrideIdx =
+    platform::is_same::value ? 2 : 3;
+
+  /// This chooses the appropriate stride element of the C tensor.
+  static int const kTensorCStrideIdx =
+    (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0);
+
+  //
+  // Maps conv output tensor refs to the epilogue's GEMM-style output iterator parameters.
+  //
+  using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter<
+    LayoutC,
+    typename Epilogue::OutputTileIterator::Layout,
+    TensorRefC,
+    ConvOperator,
+    ConvProblemSize
+  >;
+
+  /// Argument structure
+  /// Host-side, user-facing arguments: problem size, tensor refs for A/B/C/D and the
+  /// auxiliary output, epilogue params, split-K mode, and an optional reduction vector.
+  struct Arguments {
+
+    //
+    // Data members
+    //
+
+    ConvProblemSize problem_size;
+    TensorRefA ref_A;
+    TensorRefB ref_B;
+    TensorRefC ref_C;
+    TensorRefC ref_D;
+    // NOTE(review): declared TensorRefC here but the ctor below takes TensorRefAux — with the
+    // stripped template arguments these may differ upstream; verify against upstream CUTLASS.
+    TensorRefC ref_Aux;
+
+    typename EpilogueOutputOp::Params output_op;
+    SplitKMode split_k_mode;
+
+    // Optional pointer to a per-column reduction vector (nullptr to disable).
+    void * ptr_Vector;
+
+    // Leading dimension (stride) of the reduction vector between row tiles.
+    typename LayoutC::Stride::Index ldr;
+
+    //
+    // Methods
+    //
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Arguments() { }
+
+    /// Construct from problem size only (tensor refs default-initialized).
+    CUTLASS_HOST_DEVICE
+    Arguments(
+      ConvProblemSize const & problem_size
+    ):
+      problem_size(problem_size) { }
+
+    /// Construct fully-specified arguments.
+    CUTLASS_HOST_DEVICE
+    Arguments(
+      ConvProblemSize const & problem_size,
+      TensorRefA const & ref_A,
+      TensorRefB const & ref_B,
+      TensorRefC const & ref_C,
+      TensorRefC const & ref_D,
+      TensorRefAux const & ref_Aux,
+      typename EpilogueOutputOp::Params const & output_op,
+      SplitKMode const & split_k_mode = SplitKMode::kSerial,
+      void * ptr_Vector = nullptr,
+      typename LayoutC::Stride::Index ldr = 0
+    ):
+      problem_size(problem_size),
+      ref_A(ref_A),
+      ref_B(ref_B),
+      ref_C(ref_C),
+      ref_D(ref_D),
+      ref_Aux(ref_Aux),
+      output_op(output_op),
+      split_k_mode(split_k_mode),
+      ptr_Vector(ptr_Vector),
+      ldr(ldr)
+    {
+
+    }
+
+  };
+
+  /// Parameters structure
+  /// Device-side, precomputed kernel parameters derived from Arguments at launch time.
+  struct Params {
+    ConvProblemSize problem_size;
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    gemm::GemmCoord implicit_gemm_problem_size;
+    int swizzle_log_tile;
+
+    int gemm_k_iterations;
+    typename Mma::IteratorA::Params iterator_A;
+    typename Mma::IteratorA::Element const *ptr_A;
+    typename Mma::IteratorB::Params iterator_B;
+    typename Mma::IteratorB::Element const *ptr_B;
+    typename Epilogue::OutputTileIterator::Params iterator_C;
+    typename Epilogue::OutputTileIterator::Element *ptr_C;
+    typename Epilogue::OutputTileIterator::Params iterator_D;
+    typename Epilogue::OutputTileIterator::Element *ptr_D;
+    typename Epilogue::AuxOutputTileIterator::Params iterator_Aux;
+    typename Epilogue::AuxOutputTileIterator::Element *ptr_Aux;
+    typename EpilogueOutputOp::Params output_op;
+    // Workspace of per-output-tile locks used for serial split-K reduction.
+    int *semaphore;
+    SplitKMode split_k_mode;
+
+    void * ptr_Vector;
+    typename LayoutC::Stride::Index ldr;
+
+    //
+    // Methods
+    //
+
+    CUTLASS_HOST_DEVICE
+    Params():
+      swizzle_log_tile(0),
+      gemm_k_iterations(0),
+      ptr_Vector(nullptr),
+      ldr(0)
+    { }
+
+    /// Construct device parameters from host Arguments; computes the implicit-GEMM problem
+    /// size, per-tensor iterator params, gemm_k_iterations, and the swizzled grid shape.
+    CUTLASS_HOST_DEVICE
+    Params(
+      Arguments const &args,
+      int *semaphore = nullptr
+    ):
+      problem_size(args.problem_size),
+      implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
+      iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
+      ptr_A(args.ref_A.data()),
+      iterator_B(args.problem_size, args.ref_B.layout()),
+      ptr_B(args.ref_B.data()),
+      iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)),
+      ptr_C(args.ref_C.data()),
+      iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)),
+      ptr_D(args.ref_D.data()),
+      iterator_Aux(ConvOutputIteratorParameter::layout(args.ref_Aux)),
+      ptr_Aux(args.ref_Aux.data()),
+      output_op(args.output_op),
+      semaphore(semaphore),
+      split_k_mode(args.split_k_mode),
+      ptr_Vector(args.ptr_Vector),
+      ldr(args.ldr)
+
+    {
+      gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size);
+
+      ThreadblockSwizzle threadblock_swizzle;
+
+      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
+        implicit_gemm_problem_size,
+        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+        args.problem_size.split_k_slices);
+
+      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
+    }
+  };
+
+  /// Shared memory storage structure
+  // Union: the main-loop and epilogue phases never overlap, so they share one allocation.
+  union SharedStorage {
+    typename Mma::SharedStorage main_loop;
+    typename Epilogue::SharedStorage epilogue;
+  };
+
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  ImplicitGemmConvolutionWithAbsMax() { }
+
+  /// Executes one ImplicitGEMM
+  /// One threadblock computes one output tile: main-loop MMA over the gemm-K extent,
+  /// then the epilogue (with absmax and auxiliary output), coordinating split-K via the
+  /// per-tile semaphore when split_k_mode is kSerial.
+  CUTLASS_DEVICE
+  void operator()(Params const &params, SharedStorage &shared_storage) {
+
+    // Compute threadblock location
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord threadblock_tile_idx =
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
+
+    // Early exit if CTA is out of range
+    if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() ||
+      params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) {
+
+      return;
+    }
+
+    // Compute position within threadblock
+    int thread_idx = threadIdx.x;
+
+    // Construct iterators to A and B operands
+    typename Mma::IteratorA iterator_A(
+      params.iterator_A,
+      params.problem_size,
+      params.ptr_A,
+      thread_idx,
+      MatrixCoord(
+        threadblock_tile_idx.m() * Mma::Shape::kM,
+        threadblock_tile_idx.k() * Mma::Shape::kK
+      )
+    );
+
+    typename Mma::IteratorB iterator_B(
+      params.iterator_B,
+      params.problem_size,
+      params.ptr_B,
+      thread_idx,
+      MatrixCoord(
+        threadblock_tile_idx.k() * Mma::Shape::kK,
+        threadblock_tile_idx.n() * Mma::Shape::kN
+      )
+    );
+
+    // Broadcast the warp_id computed by lane 0 to ensure dependent code
+    // is compiled as warp-uniform.
+    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
+    int lane_idx = threadIdx.x % 32;
+
+    //
+    // Main loop
+    //
+
+    // Construct thread-scoped matrix multiply
+    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+
+    typename Mma::FragmentC accumulators;
+
+    accumulators.clear();
+
+    // Compute threadblock-scoped matrix multiply-add
+    // (accumulators is both source and destination of the MMA).
+    mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
+
+    //
+    // Epilogue
+    //
+
+    EpilogueOutputOp output_op(params.output_op);
+
+    // Construct the semaphore.
+    // One lock per output tile, indexed in column-major order over the grid.
+    int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m();
+
+    Semaphore semaphore(params.semaphore + block_idx, thread_idx);
+
+    // Compute logical position within grid
+    threadblock_tile_idx =
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
+
+    // If performing a reduction via split-K, fetch the initial synchronization
+    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
+
+      // Fetch the synchronization lock initially but do not block.
+      semaphore.fetch();
+
+      // Indicate which position in a serial reduction the output operator is currently updating
+      output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k());
+    }
+
+    // Element offset of this threadblock's output tile within the C/D tensors.
+    MatrixCoord threadblock_offset(
+      threadblock_tile_idx.m() * Mma::Shape::kM,
+      threadblock_tile_idx.n() * Mma::Shape::kN
+    );
+
+    // Tile iterator writing to destination tensor
+    typename Epilogue::OutputTileIterator iterator_D(
+      params.iterator_D,
+      params.ptr_D,
+      ConvOutputIteratorParameter::extent(params.problem_size),
+      thread_idx,
+      threadblock_offset
+    );
+
+    // Tile iterator writing to auxiliary tensor.
+    typename Epilogue::AuxOutputTileIterator iterator_Aux(
+      params.iterator_Aux,
+      params.ptr_Aux,
+      ConvOutputIteratorParameter::extent(params.problem_size),
+      thread_idx,
+      threadblock_offset
+    );
+
+    // Tile iterator reading from source accumulator tensor
+    typename Epilogue::OutputTileIterator iterator_C(
+      params.iterator_C,
+      params.ptr_C,
+      ConvOutputIteratorParameter::extent(params.problem_size),
+      thread_idx,
+      threadblock_offset
+    );
+
+    // Define the reduction output pointer and move to the appropriate place
+    // NOTE(review): static_cast's template argument was stripped — presumably
+    // static_cast<typename Epilogue::ElementVector *>; verify against upstream.
+    typename Epilogue::ElementVector *ptr_Vector =
+        static_cast(params.ptr_Vector);
+
+
+    // Construct the epilogue
+    Epilogue epilogue(
+      shared_storage.epilogue,
+      thread_idx,
+      warp_idx,
+      lane_idx);
+
+    // Move to appropriate location for this output tile
+    if (ptr_Vector) {
+      ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr;
+    }
+
+    // Wait on the semaphore - this latency may have been covered by iterator construction
+    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
+
+      // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
+      if (threadblock_tile_idx.k()) {
+        iterator_C = iterator_D;
+      }
+
+      // Block until the previous k-slice has released this tile's lock.
+      semaphore.wait(threadblock_tile_idx.k());
+
+    }
+    // Each split-k-slice writes to a unique tensor location
+    else if (params.split_k_mode == SplitKMode::kParallel) {
+      iterator_D.add_pointer_offset(threadblock_tile_idx.k() *
+        cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size));
+    }
+
+    // Execute the epilogue operator to update the destination tensor.
+    epilogue(output_op,
+             // Only the final block uses Vector
+             ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) &&
+              (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1))
+               ? nullptr
+               : ptr_Vector,
+             iterator_D,
+             accumulators,
+             iterator_C,
+             iterator_Aux,
+             ConvOutputIteratorParameter::extent(params.problem_size),
+             threadblock_offset);
+
+    //
+    // Release the semaphore
+    //
+
+    if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) {
+
+      int lock = 0;
+      if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) {
+
+        // The final threadblock resets the semaphore for subsequent grids.
+        lock = 0;
+      }
+      else {
+        // Otherwise, the semaphore is incremented
+        lock = threadblock_tile_idx.k() + 1;
+      }
+
+      semaphore.release(lock);
+    }
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace conv
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h
new file mode 100644
index 0000000000000000000000000000000000000000..1e810e3d13c8b8eed4894ac9670f4a586dcaef8d
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h
@@ -0,0 +1,499 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem +> +struct ImplicitGemmConvolutionWithFusedEpilogue { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on 
different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + void * ptr_Vector; + void * ptr_Tensor; + + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial, + void * ptr_Vector = nullptr, + void * ptr_Tensor = nullptr, + typename LayoutC::Stride::Index ldr = 0, + typename LayoutC::Stride::Index ldt = 0 + ): + problem_size(problem_size), + ref_A(ref_A), + 
ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + split_k_mode(split_k_mode), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + ldr(ldr), + ldt(ldt) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + typename Epilogue::TensorTileIterator::Params params_Tensor; + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + void * ptr_Tensor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + gemm_k_iterations(0), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr) + { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), + 
ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + params_Tensor(args.ldt), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor) + + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionWithFusedEpilogue() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // 
Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + typename Epilogue::ElementTensor *ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? 
nullptr + : ptr_Tensor, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; + } + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + ConvOutputIteratorParameter::extent(params.problem_size), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..327fc27db4eba8093ce58845e465071da724c2e8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -0,0 +1,874 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/arch/cluster_sm90.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/sm100_pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class ConvUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + + 
// Mainloop derived types + using ProblemShape = ProblemShape_; + using CollectiveMainloop = CollectiveMainloop_; + + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; + static constexpr bool is_grouped_wgrad = CollectiveMainloop::is_grouped_wgrad; + static constexpr bool IsComplex = false; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + // TileID scheduler + // CLC pipeline depth determines how many waves (stages-1) the scheduler can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = 
DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename cutlass::gemm::kernel::detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + + // Pipelines and pipeline states + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename 
CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineDetail::PipelineCLCFetchAsyncPipelineState; + using CLCPipelineSharedStorage = cutlass::PipelineDetail::PipelineCLCFetchAsyncSharedStorage; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = CLCPipelineSharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, 
"SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + using ProblemShapeMNKL = decltype(CollectiveMainloop::get_problem_shape_MNKL(ProblemShape{})); + ProblemShapeMNKL problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + }; + + // + // Methods + // + // Map user facing arguments to device facing params + CUTLASS_HOST + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + + auto problem_shape_mnkl = CollectiveMainloop::get_problem_shape_MNKL(args.problem_shape); + + auto mainloop_params = CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace, args.hw_info); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shape_mnkl, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + 
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + problem_shape_mnkl, + mainloop_params, + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace), + args.hw_info + }; + } + + CUTLASS_HOST + static bool + can_implement(Arguments const& args) { + bool implementable = true; + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + + // implicit gemm B tile can be small for conv, ensure multicast smem offsets are 128B aligned + int multicast_b_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape)) * sizeof_bits_v; + int multicast_b_fallback_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape_fallback)) * sizeof_bits_v; + implementable &= multicast_b_bits % (128*8) == 0 && multicast_b_fallback_bits % (128*8) == 0; + if (not implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: multicast size too large for B tile\n"); + return false; + } + + if constexpr (is_grouped_wgrad) { + implementable &= 
size<0>(cluster_shape) == 1 && size<0>(cluster_shape_fallback) == 1; + + if (!implementable) { + return false; + } + } + + return implementable; + } + + CUTLASS_HOST + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + CUTLASS_HOST + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + auto linear_problem_shape_MNKL = cutlass::conv::detail::get_linearized_problem_shape_MNKL(args.problem_shape); + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace + ( + args.scheduler, workspace_ptr + workspace_offset, stream, linear_problem_shape_MNKL, + args.hw_info, 
NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + + workspace_offset += TileScheduler::template get_workspace_size + ( + args.scheduler, linear_problem_shape_MNKL, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, + CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + CUTLASS_HOST + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + return TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape + ,params.hw_info + ); + } + + CUTLASS_HOST + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + auto problem_shape_MNKL = append<4>(params.problem_shape, _1{}); + auto [M, N, K, L] = problem_shape_MNKL; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? 
WarpCategory(warp_idx) + : WarpCategory::Epilogue; + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + collective_mainloop.prefetch_tma_descriptors(); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + collective_epilogue.prefetch_tma_descriptors(params.epilogue); + } + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue) // epilogue + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == 
warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 1; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 
0 : 1; + load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.initializing_warp = 3; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 4; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
+ accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 5; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + uint32_t tmem_stage_ptrs[AccumulatorPipelineStageCount]; + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = 
cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, problem_shape_MNKL, TileShape{}, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto accumulators = TiledMma::make_fragment_C(acc_shape); + + int TmemColumnsPerAccumulatorTile = cutlass::detail::find_tmem_tensor_col_offset(accumulators); + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + Tensor gA_mk = get<0>(load_inputs); + + do { + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, TileShape{}, shape<3>(gA_mk)); + auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); + auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + + else if (is_participant.sched) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. 
+ bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + + else if (is_participant.mma) { + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + + CUTLASS_PRAGMA_UNROLL + for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; acc_stage++) { + tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; + } + auto mma_inputs = collective_mainloop.mma_init(shared_storage.tensors.mainloop); + + do { + auto k_tile_count = scheduler.get_work_k_tile_count(work_tile_info, problem_shape_MNKL, TileShape{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, 
+ clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + + // Accumulator stage slice + int acc_stage = accumulator_pipe_producer_state.index(); + accumulators.data() = tmem_stage_ptrs[acc_stage]; + + if (is_mma_leader_cta) { + mainloop_pipe_consumer_state = collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + mma_inputs, + k_tile_count + ); + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + 
cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + CUTLASS_PRAGMA_UNROLL + for (int acc_stage = 0; acc_stage < AccumulatorPipelineStageCount; acc_stage++) { + tmem_stage_ptrs[acc_stage] = tmem_base_ptr + (TmemColumnsPerAccumulatorTile * acc_stage) & cutlass::detail::TmemColMask; + } + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice after making sure allocation has been performed + int acc_stage = accumulator_pipe_consumer_state.index(); + accumulators.data() = tmem_stage_ptrs[acc_stage]; + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + work_tile_info = next_work_tile_info; + cta_coord_mnkl = 
scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c02a4531edd4078da6c92205f36b62b237c20bc --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class ConvUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t> +> : public cutlass::gemm::kernel::GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_ +> +{}; +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..41eaba2f64b1c14fd85de632b1bfe8c9a3efbc1e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/thread/depthwise_mma.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates exposing architecture support for depthwise convolution +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/thread/mma.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// MMA operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Inner product operator + typename Operator +> +struct ElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// General implementation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_> +struct ElementwiseInnerProduct { + using Shape = Shape_; + using Operator = arch::OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_HOST_DEVICE + void operator()(Array &d, + Array const &a, + Array const &b, + Array const &c) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Shape::kN; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization of half_t +template <> +struct ElementwiseInnerProduct< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + half_t, + half_t, + arch::OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using 
Operator = arch::OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 tmp_D = __hfma2(A, B, C); + + d = reinterpret_cast const &>(tmp_D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Concept: arch::OpMultiplyAdd or arch::Mma<> + typename Operator = arch::OpMultiplyAdd, + /// Used for partial specialization + typename Enable = bool +> +struct DepthwiseDirectConvElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gemplate that handles all packed matrix layouts +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_, + /// Operator used to compute GEMM + typename Operator_ +> +struct DepthwiseDirectConvElementwiseInnerProductGeneric { + + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = 
ElementC_; + + /// Underlying mathematical operator + using Operator = Operator_; + + /// A operand storage + using FragmentA = Array; + + /// B operand storage + using FragmentB = Array; + + /// C operand storage + using FragmentC = Array; + + /// Instruction + using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< + gemm::GemmShape, + 1, + ElementA, + ElementB, + ElementC, + Operator>; + + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + Array *ptr_D = reinterpret_cast *>(&D); + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array const *ptr_B = + reinterpret_cast const *>(&B); + + MmaOp mma_op; + + // Copy accumulators + D = C; + + // Compute matrix product + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { + + Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpB = ptr_B[n]; + + mma_op(tmpD, tmpA, tmpB, tmpD); + + ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_ +> +struct DepthwiseDirectConvElementwiseInnerProduct< + Shape_, + ElementA_, + ElementB_, + ElementC_, + arch::OpMultiplyAdd + > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of 
operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = arch::OpMultiplyAdd; + + /// A operand storage + using FragmentA = + Array; // output_tile_size per thread * groups_per_thread + + /// B operand storage + using FragmentB = Array; // 1 * groups_per_thread + + /// C operand storage + using FragmentC = + Array; // output_tile_size per thread * groups_per_thread + + static bool const use_optimized = 0; + + using ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + + ArchMmaOperator mma; + + mma(D, A, B, C); + + } +}; + +} // namespace thread +} // namespace conv +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..2da2b73b3afe3d5f5800c84d2edb2b220003ba83 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -0,0 +1,485 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradFilterTileAccessIteratorAnalytic; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic strided dgrad needs special handling to skip MMAs +// on non-contributing w positions +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + 
static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + 
void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Moves filter_s + filter_s_ += problem_size_.stride_w; + if (filter_s_ < problem_size_.S) { + return; + } + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r + filter_r_ += problem_size_.stride_h; + if (filter_r_ < problem_size_.R) { + return; + } + // Restore filter_r + filter_r_ = start_r_; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorAnalytic unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ +>{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + 
char const *pointer_; + + // For a fixed filter position (r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_r_; + int filter_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; 
++s) { + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = offset_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..8a5e60b9d134d8ec5d28da7e486bc5c7f6629a39 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -0,0 +1,619 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradFilterTileAccessIteratorOptimized; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ + > { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = 
conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Parameters structure + // + + struct Params : Conv2dStridedDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dStridedDgradFilterIteratorOptimizedParams const &base): + Conv2dStridedDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dStridedDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_k_; + int filter_r_; + int filter_s_; + + int start_r_; + int start_s_; + + int64_t reset_bytes_s_; + int64_t reset_bytes_r_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized( + 
Conv2dStridedDgradFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = threadblock_offset.column() + thread_coord.contiguous(); + + reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; + reset_bytes_r_ = reset_bytes_s_ + + (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + TensorCoord coord{filter_k_, filter_r_, filter_s_, column}; + + pointer_ += params_.layout(coord) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void advance() { + + int next_idx = 0; + LongIndex reset_bytes = params_.reset_bytes; + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ >= problem_size_.S) { + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; +#if 0 + bool check = (filter_r_ < problem_size_.R); + + filter_r_ = check ? filter_r_ : start_r_; + next_idx = check ? 1 : 2; + reset_bytes += (check ? 
reset_bytes_s_ : reset_bytes_r_); +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .s64 t1;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 t1, %6, %7, %%p;\n\t" + " add.s64 %2, %8, t1;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); +#endif + } + + // offset pointers by offset_bytes + pointer_ += (params_.inc_next[next_idx] - reset_bytes); + + if (next_idx == 2) { + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + 
return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradFilterTileAccessIteratorOptimized unity strided dgrad is more performant for dgrad +// on problem sizes with stride = {1x1} +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradFilterTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ + > { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be 
divisible by the access type."); + + // + // Parameters structure + // + + struct Params : Conv2dDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dDgradFilterIteratorOptimizedParams const &base): + Conv2dDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv2dDgradFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_rs_; + int filter_k_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized( + Conv2dDgradFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_rs_(0), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = 
threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + uint32_t pred = ((filter_k < problem_size_.K && (filter_c + v * AccessType::kElements) < problem_size_.C) ? 1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + pointer_ += ( + filter_k_ * params.layout.stride()[2] + column + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_rs; + + // moves to the next tile + ++filter_rs_; + if (filter_rs_ == params_.RS) { + + filter_rs_ = 0; + next = params_.inc_next_k; + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) 
{ + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..b33645c1783c8b12cc9d8d6e1d93dbffb3f47f1c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using +// unscaled coordinations +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const 
kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + int filter_r = filter_r_; + int filter_s = filter_s_; + + if 
(problem_size_.mode == Mode::kConvolution) { + filter_r = (problem_size_.R - 1 - filter_r); + filter_s = (problem_size_.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + problem_size_, + stride_h_divmod, stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + // Effective P and Q for filter position required for remapping NHW rows + int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; + int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; + + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; + + // (STEP 1) [reorder NHW rows to start with same filter positions] + offset_n_[s] = offset_npq / (P * Q); + int residual = offset_npq % (P * Q); + + int p = (residual / Q); + int q = (residual % Q); + + int mapped_h = (start_h + p * problem_size_.stride_h); + int mapped_w = (start_w + q * problem_size_.stride_w); + + // Access (p, q) coordinates for Dy tensor and a filter position in gemm_k=0 + // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are divisible + // by stride_h and stride_w + offset_p_[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; + offset_q_[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = 
residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ < problem_size_.S) { + return; + } + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; + if (filter_r_ < problem_size_.R) { + return; + } + + // Restore filter_r + filter_r_ = start_r_; + + // Move filter_k + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int conv_sign = (problem_size_.mode == Mode::kConvolution ? 
1 : -1); + + p += (conv_sign * (filter_r_ / problem_size_.stride_h)); + q += (conv_sign * (filter_s_ / problem_size_.stride_w)); + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord( + n, + p, + q, + k); + } + + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return + coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2dDgradOutputGradientTileAccessIteratorAnalytic for unity strides can be optimized by +// eliminating modulo arithmetic to compute unscaled coordinates +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kUnity, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // 
+ // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int filter_r_; + int filter_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_w_[ThreadMap::Iterations::kStrided]; + int offset_h_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_nhw / (problem_size_.H * problem_size_.W); + int residual = offset_nhw % (problem_size_.H * problem_size_.W); + + offset_h_[s] = residual / problem_size_.W; + offset_w_[s] = residual % problem_size_.W; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + 
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // move to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int n = offset_n_[iteration_strided_]; + int h = offset_h_[iteration_strided_]; + int w = offset_w_[iteration_strided_]; + + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h) / problem_size_.stride_h; + int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w) / problem_size_.stride_w; + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord(n, p, q, k); + } + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 
8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // Conv2dDgradFilterTileAccessIteratorAnalytic unity stride specialization + // only supports (stride_h, stride_w) = (1, 1) + if (problem_size.stride() != MatrixCoord({1, 1})) { + return Status::kErrorNotSupported; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..638c6607095ce85f7c1b135296d974bdf295621a --- /dev/null +++ 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized; +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Conv2dDgradOutputGradientTileAccessIteratorOptimized strided dgrad needs special handling +// to skip MMAs (Dx = Dy * w) on invalid filter positions +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename 
Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ +> +class Conv2dDgradOutputGradientTileAccessIteratorOptimized < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided, + AccessType_ +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = uint64_t; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dStridedDgradOutputGradientIteratorOptimizedParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + int filter_k_; + int filter_r_; + int filter_s_; + int start_r_; + int start_s_; + int64_t reset_bytes_s_; + int64_t reset_bytes_r_; + + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; + +public: + + CUTLASS_HOST_DEVICE + 
Conv2dDgradOutputGradientTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + filter_k_(0), + filter_r_(start_r), + filter_s_(start_s), + start_r_(start_r), + start_s_(start_s) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + reset_bytes_s_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0]; + + reset_bytes_r_ = (problem_size_.num_gemm_k_filter_s(start_s_) - 1) * params_.inc_next[0] + + (problem_size_.num_gemm_k_filter_r(start_r_) - 1) * params_.inc_next[1]; + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + int filter_r = filter_r_; + int filter_s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + filter_r = (problem_size_.R - 1 - filter_r); + filter_s = (problem_size_.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + problem_size_, + stride_h_divmod, stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + + // Effective starting P and Q for filter position required for remapping NHW rows + int P = (problem_size_.H - start_h + problem_size_.stride_h - 1) / problem_size_.stride_h; + int Q = (problem_size_.W - start_w + problem_size_.stride_w - 1) / problem_size_.stride_w; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_npq = (threadblock_offset.row() + thread_coord.strided() + s * 
ThreadMap::Delta::kStrided) % params_.tiled_rows_per_filter; + + // (STEP 1) [reorder NHW rows to start with same filter positions] + offset_n[s] = offset_npq / (P * Q); + int residual = offset_npq % (P * Q); + + int p = (residual / Q); + int q = (residual % Q); + + int mapped_h = (start_h + p * problem_size_.stride_h); + int mapped_w = (start_w + q * problem_size_.stride_w); + + // Access (p, q) coordinates for Dy tensor for filter position in gemm_k=0 + // note that (h + pad_h - filter_r) and (w + pad_w - filter_s) are ensured to be + // divisible by stride_h and stride_w + offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; + offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; + + // Initialize pointers for gemm_k=0 + TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + // + // Precompute mask predicates + // + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = start_r; r < problem_size_.R; r += problem_size_.stride_h) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int p = offset_p[s_idx] ; + + p += (params_.conv_sign * (r / problem_size_.stride_h)); + + bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for(int s = start_s; s < problem_size_.S; s += problem_size_.stride_w) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int q = offset_q[s_idx]; + q += (params_.conv_sign * (s / problem_size_.stride_w)); + + bool pred = (q >=0 && q < problem_size_.Q); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } + } + } + + 
CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size.K); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}); + } + +private: + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset, LongIndex byte_reset = 0) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset - byte_reset; + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void advance() { + + int next_idx = 0; + int64_t reset_bytes = 0; + + // Move filter_s by stride_w + filter_s_ += problem_size_.stride_w; + if (filter_s_ >= problem_size_.S) { + + // Restore filter_s + filter_s_ = start_s_; + + // Move filter_r by stride_h + filter_r_ += problem_size_.stride_h; +#if 0 + if (filter_r_ < problem_size_.R) { + + next_idx = 1; + + // Restore bytes in q coordinate (Mma in filter s dimension) + reset_bytes = reset_bytes_s_; + + } else { + + // Restore filter_r + filter_r_ = start_r_; + + next_idx = 2; + + // Restore bytes in p and q coordinate (Mma in filter s and r dimension) + reset_bytes = reset_bytes_r_; + } +#else + asm volatile( + "{\n\t" + " .reg 
.pred %%p;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 %2, %6, %7, %%p;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_)); +#endif + } + + // offset pointers by offset_bytes + add_byte_offset_(params_.inc_next[next_idx] - reset_bytes); + + if (next_idx == 2) { + filter_k_ += params_.filter_k_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K); + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0]; + masks_[s][v][1] = clear ? 
Mask(0) : masks_[s][v][1]; + } + } + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + return + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+  CUTLASS_HOST_DEVICE
+  static Status can_implement(Conv2dProblemSize const &problem_size) {
+
+    // check alignment constraint on iterator's contiguous dimension
+    if (problem_size.K % AccessType::kElements) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    // Limit on filter size
+    if (problem_size.R > 32 || problem_size.S > 32) {
+      return Status::kErrorNotSupported;
+    }
+
+    return Status::kSuccess;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// Conv2dDgradOutputGradientTileAccessIteratorOptimized unity stride dgrad is optimized for dgrad
+// with problem stride = {1x1}
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  typename Shape_,
+  typename Element_,
+  typename ThreadMap_,
+  typename AccessType_
+>
+class Conv2dDgradOutputGradientTileAccessIteratorOptimized <
+  Shape_,
+  Element_,
+  ThreadMap_,
+  conv::StrideSupport::kUnity,
+  AccessType_
+> {
+public:
+
+  //
+  // Types
+  //
+
+  // NOTE(review): several template-argument lists in this vendored copy appear to have been
+  // stripped during generation (e.g. 'cutlass::TensorRef', 'reinterpret_cast(ptr)' and
+  // 'sizeof_bits::value' below have empty argument lists); compare against the upstream
+  // CUTLASS header before relying on this file.
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::TensorNHWC;
+  using TensorCoord = typename Layout::TensorCoord;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+  using TensorRef = cutlass::TensorRef;
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized;
+  static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity;
+  static int const kConvDim = 2;
+  using ConvProblemSize = typename conv::Conv2dProblemSize;
+
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
+  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
+    "Vectors implied by the thread map must be divisible by the access type.");
+
+  // 64-bit word of filter-position predicates (one bit per filter position)
+  using Mask = uint64_t;
+
+  //
+  // Simplifying assertions
+  //
+  static_assert(ThreadMap::Iterations::kContiguous == 1,
+    "Require Iterations::kContiguous == 1");
+
+  //
+  // Parameters structure
+  //
+
+  using Params = Conv2dDgradOutputGradientIteratorOptimizedParams;
+
+private:
+
+  Conv2dDgradOutputGradientIteratorOptimizedParams const &params_;
+  Conv2dProblemSize const &problem_size_;
+  LongIndex iteration_contiguous_;
+  LongIndex iteration_strided_;
+  LongIndex iteration_vector_;
+
+  // One pointer per access
+  char const *pointer_[ThreadMap::Iterations::kStrided];
+
+  // current filter position (r, s)
+  int filter_r_;
+  int filter_s_;
+  int filter_k_;
+
+  // masks_[s][v][0] holds one predicate bit per filter row index r, and masks_[s][v][1]
+  // one bit per filter column index s (see the two precompute loops in the constructor)
+  Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2];
+
+public:
+
+  CUTLASS_HOST_DEVICE
+  Conv2dDgradOutputGradientTileAccessIteratorOptimized(
+    Conv2dDgradOutputGradientIteratorOptimizedParams const &params,
+    Conv2dProblemSize const &problem_size,
+    Element const *ptr,
+    int thread_idx,
+    MatrixCoord const &threadblock_offset = MatrixCoord()       // tile index - units are threadblock-scoped tiles
+  ):
+    params_(params),
+    problem_size_(problem_size),
+    filter_k_(0),
+    filter_r_(0),
+    filter_s_(0) {
+
+    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
+
+    filter_k_ = threadblock_offset.column() + thread_coord.contiguous();
+
+    int offset_n[ThreadMap::Iterations::kStrided];
+    int offset_h[ThreadMap::Iterations::kStrided];
+    int offset_w[ThreadMap::Iterations::kStrided];
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+
+      pointer_[s] = reinterpret_cast(ptr);
+
+      int offset_nhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+
+      // The subsequent fast_divmod() operations are equivalent to the following logical computation:
+      //
+      //
+      // offset_n[s] = offset_nhw / (problem_size_.H * problem_size_.W);
+      // int residual = offset_nhw % (problem_size_.H * problem_size_.W);
+      //
+      // offset_h[s] = residual / problem_size_.W;
+      // offset_w[s] = residual % problem_size_.W;
+      //
+
+      int residual;
+
+      params_.hw_divmod(offset_n[s], residual, offset_nhw);
+      params_.w_divmod(offset_h[s], offset_w[s], residual);
+
+      TensorCoord coord = at_(offset_n[s], offset_h[s], offset_w[s], 0, 0);
+
+      pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8;
+    }
+
+    clear_mask();
+
+    // Precompute one predicate bit per filter row r (word [0] of masks_)
+    CUTLASS_PRAGMA_NO_UNROLL
+    for (int r = 0; r < problem_size_.R; ++r) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
+
+        int r_ = r;
+        if (problem_size_.mode == Mode::kConvolution) {
+          r_ = problem_size_.R - 1 - r;
+        }
+
+        int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h;
+
+        bool pred = (offset_n[s_idx] < problem_size_.N && p >= 0 && p < problem_size_.P);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+          masks_[s_idx][v_idx][0] |= (pred << r);
+        }
+      }
+    }
+
+    // Precompute one predicate bit per filter column s (word [1] of masks_)
+    CUTLASS_PRAGMA_NO_UNROLL
+    for (int s = 0; s < problem_size_.S; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) {
+
+        int s_ = s;
+        if (problem_size_.mode == Mode::kConvolution) {
+          s_ = problem_size_.S - 1 - s;
+        }
+
+        int q = offset_w[s_idx] + problem_size_.pad_w - s_ * problem_size_.dilation_w;
+
+        bool pred = (q >= 0 && q < problem_size_.Q);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+          masks_[s_idx][v_idx][1] |= (pred << s);
+        }
+      }
+    }
+
+    // Disable any vector accesses that fall outside the K extent
+    CUTLASS_PRAGMA_UNROLL
+    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+      clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K);
+    }
+
+    set_iteration_index(0);
+  }
+
+  CUTLASS_HOST_DEVICE
+  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
+    return Params(problem_size,
+                  layout,
+                  sizeof_bits::value,
+                  {Shape::kRow, Shape::kColumn},
+                  ThreadMap::kThreads,
+                  ThreadMap::kElementsPerAccess,
+                  {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
+                  {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided});
+  }
+
+private:
+
+  /// Returns the coordinate in the output gradient tensor dy that is corresponding to
+  // activation nhw and filter position k, r, s
+  CUTLASS_HOST_DEVICE
+  TensorCoord at_(int n, int h, int w, int r, int s) const {
+
+    // Cross-correlation vs. convolution: convolution flips the filter
+    if (problem_size_.mode == Mode::kConvolution) {
+      r = problem_size_.R - 1 - r;
+      s = problem_size_.S - 1 - s;
+    }
+
+    int p = h + problem_size_.pad_h - r * problem_size_.dilation_h;
+    int q = w + problem_size_.pad_w - s * problem_size_.dilation_w;
+
+    return TensorCoord(n, p, q, filter_k_);
+  }
+
+  /// Adds a pointer offset in units of element
+  CUTLASS_HOST_DEVICE
+  void add_byte_offset_(LongIndex byte_offset) {
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      pointer_[s] += byte_offset;
+    }
+  }
+
+public:
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(Index index) {
+    iteration_vector_ = index % kAccessesPerVector;
+    int residual_access = index / kAccessesPerVector;
+    iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
+    iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
+  }
+
+  /// Adds a pointer offset in units of element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    add_byte_offset_(pointer_offset * sizeof_bits::value / 8);
+  }
+
+  CUTLASS_HOST_DEVICE
+  void advance() {
+
+    // next_idx selects which precomputed pointer increment to apply:
+    //   0 -> next filter column (s), 1 -> next filter row (r), 2 -> next K tile
+    int next_idx = 0;
+
+    // moves to the next tile
+    ++filter_s_;
+    if (filter_s_ == problem_size_.S) {
+      filter_s_ = 0;
+      ++filter_r_;
+
+      if (filter_r_ < problem_size_.R) {
+        next_idx = 1;
+      }
+      else {
+        filter_r_ = 0;
+        next_idx = 2;
+      }
+    }
+
+    add_byte_offset_(params_.inc_next[next_idx]);
+
+    if (next_idx == 2) {
+      filter_k_ += params_.filter_k_delta;
+    }
+
+    // Re-evaluate the K-extent predicate after filter_k_ may have advanced
+    CUTLASS_PRAGMA_UNROLL
+    for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) {
+      clear_mask(v_idx, (filter_k_ + v_idx * AccessType::kElements) >= problem_size_.K);
+    }
+  }
+
+  /// Clears the predicates
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool clear = true) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int v = 0; v < kAccessesPerVector; ++v) {
+        masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
+        masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
+      }
+    }
+  }
+
+  /// Clears the predicates (for vector access index v only)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(int v, bool clear = true) {
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      masks_[s][v][0] = clear ? Mask(0) : masks_[s][v][0];
+      masks_[s][v][1] = clear ? Mask(0) : masks_[s][v][1];
+    }
+  }
+
+  // Returns true when both the r-bit and the s-bit of the current access's masks are set
+  CUTLASS_HOST_DEVICE
+  bool valid() {
+
+    return
+      (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) &&
+      (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_));
+  }
+
+  /// Returns a pointer to the vector starting at the current coordinate
+  CUTLASS_HOST_DEVICE
+  AccessType const *get() const {
+
+    return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_;
+  }
+
+  /// Increments to the next memory access
+  CUTLASS_HOST_DEVICE
+  Conv2dDgradOutputGradientTileAccessIteratorOptimized &operator++() {
+    ++iteration_vector_;
+    if (iteration_vector_ < kAccessesPerVector) {
+      return *this;
+    }
+    iteration_vector_ = 0;
+
+    ++iteration_contiguous_;
+    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
+      return *this;
+    }
+    iteration_contiguous_ = 0;
+
+    ++iteration_strided_;
+    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
+      return *this;
+    }
+    iteration_strided_ = 0;
+
+    return *this;
+  }
+
+  /// Determines whether the Implicit GEMM can execute the given problem.
+  CUTLASS_HOST_DEVICE
+  static Status can_implement(Conv2dProblemSize const &problem_size) {
+
+    // This is specialized for unit stride
+    if (problem_size.stride() != MatrixCoord({1, 1})) {
+      return Status::kErrorNotSupported;
+    }
+
+    // check alignment constraint on iterator's contiguous dimension
+    if (problem_size.K % AccessType::kElements) {
+      return Status::kErrorNotSupported;
+    }
+
+    // Limit on filter size
+    if (problem_size.R > 32 || problem_size.S > 32) {
+      return Status::kErrorNotSupported;
+    }
+    return Status::kSuccess;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace conv
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4eb011e1c675757b9f1fa3111c2de0db658cad5
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h
@@ -0,0 +1,332 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2.
Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
+    matrix from memory.
+
+    This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory.
+
+    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
+    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+#include "cutlass/coord.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/predicate_vector.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/tensor_view.h"
+#include "cutlass/layout/pitch_linear.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/conv2d_problem_size.h"
+#include "cutlass/conv/threadblock/conv2d_params.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Analytic (div/mod-per-access) tile access iterator over the fprop activation tensor.
+// NOTE(review): several template-argument lists in this vendored copy appear stripped
+// (e.g. 'cutlass::AlignedArray', 'cutlass::TensorRef' below); compare against the
+// upstream CUTLASS header before relying on this file.
+template <
+  typename Shape_,
+  typename Element_,
+  typename Layout_,
+  typename ThreadMap_,
+  typename AccessType_ = cutlass::AlignedArray,
+  conv::GroupMode GroupMode_ = conv::GroupMode::kNone
+>
+class Conv2dFpropActivationTileAccessIteratorAnalytic {
+public:
+
+  //
+  // Types
+  //
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = Layout_;
+  using TensorCoord = typename Layout::TensorCoord;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+  using TensorRef = cutlass::TensorRef;
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic;
+  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
+  static int const kConvDim = 2;
+  using ConvProblemSize = typename conv::Conv2dProblemSize;
+  static conv::GroupMode const kGroupMode = GroupMode_;
+
+  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
+
+  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
+    "Vectors implied by the thread map must be divisible by the access type.");
+
+  //
+  // Simplifying assertions
+  //
+  static_assert(ThreadMap::Iterations::kContiguous == 1,
+    "Require Iterations::kContiguous == 1");
+
+  //
+  // Parameters structure
+  //
+
+  using Params = Conv2dAnalyticParams;
+
+private:
+
+  Params const &params_;
+  Conv2dProblemSize const &problem_size_;
+  LongIndex iteration_contiguous_;
+  LongIndex iteration_strided_;
+  LongIndex iteration_vector_;
+  char const *pointer_;
+
+  // current filter position (c, r, s) plus grouped-convolution bookkeeping
+  int filter_c_;
+  int filter_r_;
+  int filter_s_;
+  int filter_c_init_;
+  int group_idx_offset_;
+  int channels_per_group_;
+  int crs_cnt_;
+  int crs_per_group_;
+
+  // per-strided-access decomposition of the GEMM row index into (n, p, q)
+  int offset_n_[ThreadMap::Iterations::kStrided];
+  int offset_p_[ThreadMap::Iterations::kStrided];
+  int offset_q_[ThreadMap::Iterations::kStrided];
+
+public:
+
+  CUTLASS_HOST_DEVICE
+  Conv2dFpropActivationTileAccessIteratorAnalytic(
+    Params const &params,
+    Conv2dProblemSize const &problem_size,
+    Element const *ptr,
+    int thread_idx,
+    MatrixCoord const &threadblock_offset = MatrixCoord()       // tile index - units are threadblock-scoped tiles
+  ):
+    params_(params),
+    problem_size_(problem_size),
+    pointer_(reinterpret_cast(ptr)),
+    crs_cnt_(0),
+    group_idx_offset_(0),
+    filter_c_(0),
+    filter_r_(0),
+    filter_s_(0) {
+
+    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
+
+    filter_c_ = threadblock_offset.column() + thread_coord.contiguous();
+
+    if (kGroupMode != conv::GroupMode::kNone) {
+      filter_c_init_ = filter_c_;
+      channels_per_group_ = problem_size_.C / problem_size_.groups;
+      crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn);
+    }
+
+    // Analytic variant: decompose the GEMM row index with plain integer div/mod
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+
+      offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q);
+      int residual = offset_npq % (problem_size_.P * problem_size_.Q);
+
+      offset_p_[s] = residual / problem_size_.Q;
+      offset_q_[s] = residual % problem_size_.Q;
+    }
+
+    set_iteration_index(0);
+  }
+
+  CUTLASS_HOST_DEVICE
+  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
+    return Params(problem_size, layout);
+  }
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(Index index) {
+    iteration_vector_ = index % kAccessesPerVector;
+    int residual_access = index / kAccessesPerVector;
+    iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
+    iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
+  }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    pointer_ += pointer_offset * sizeof_bits::value / 8;
+  }
+
+  // Advances to the next (r, s, c) tile: s fastest, then r, then the channel tile
+  CUTLASS_HOST_DEVICE
+  void advance() {
+    // moves to the next tile
+    if (kGroupMode != conv::GroupMode::kNone) {
+      ++crs_cnt_;
+    }
+
+    ++filter_s_;
+    if (filter_s_ < problem_size_.S) {
+      return;
+    }
+    filter_s_ = 0;
+    ++filter_r_;
+    if (filter_r_ < problem_size_.R) {
+      return;
+    }
+    filter_r_ = 0;
+
+    if (kGroupMode == conv::GroupMode::kNone) {
+      filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
+    } else {
+      if (crs_cnt_ == crs_per_group_) {
+        // moves to next group
+        crs_cnt_ = 0;
+        ++group_idx_offset_;
+        filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_;
+      } else {
+        filter_c_ += Shape::kColumn * problem_size_.split_k_slices;
+      }
+    }
+  }
+
+  /// Returns the coordinate in the activations tensor X that is currently pointed to
+  /// by the iterator.
+  CUTLASS_HOST_DEVICE
+  TensorCoord at() const {
+    int n = offset_n_[iteration_strided_];
+    int p = offset_p_[iteration_strided_];
+    int q = offset_q_[iteration_strided_];
+
+    int r = filter_r_;
+    int s = filter_s_;
+
+    // Cross-correlation vs. convolution: convolution flips the filter
+    if (problem_size_.mode == Mode::kConvolution) {
+      r = (problem_size_.R - 1 - filter_r_);
+      s = (problem_size_.S - 1 - filter_s_);
+    }
+
+    int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h;
+    int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w;
+
+    int c = filter_c_ + iteration_vector_ * AccessType::kElements;
+
+    return TensorCoord(n, h, w, c);
+  }
+
+  /// Returns true if the current coordinate is within the activations tensor X
+  CUTLASS_HOST_DEVICE
+  bool valid() const {
+
+    TensorCoord coord = at();
+
+    // h/w may be negative because of padding, hence the explicit >= 0 checks
+    return coord.n() < problem_size_.N &&
+      coord.h() >= 0 && coord.h() < problem_size_.H &&
+      coord.w() >= 0 && coord.w() < problem_size_.W &&
+      coord.c() < problem_size_.C;
+  }
+
+  /// Returns a pointer to the vector starting at the current coordinate
+  CUTLASS_HOST_DEVICE
+  AccessType const *get() const {
+
+    TensorCoord coord = at();
+    LongIndex offset = params_.layout(coord);
+
+    AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8);
+
+    return ptr;
+  }
+
+  /// Increments to the next memory access
+  CUTLASS_HOST_DEVICE
+  Conv2dFpropActivationTileAccessIteratorAnalytic &operator++() {
+    ++iteration_vector_;
+    if (iteration_vector_ < kAccessesPerVector) {
+      return *this;
+    }
+    iteration_vector_ = 0;
+
+    ++iteration_contiguous_;
+    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
+      return *this;
+    }
+    iteration_contiguous_ = 0;
+
+    ++iteration_strided_;
+    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
+      return *this;
+    }
+    iteration_strided_ = 0;
+
+    return *this;
+  }
+
+  /// Determines whether the Implicit GEMM can execute the given problem.
+  CUTLASS_HOST_DEVICE
+  static Status can_implement(Conv2dProblemSize const &problem_size) {
+
+    // check alignment constraint on iterator's contiguous dimension
+    if ((problem_size.C / problem_size.groups) % AccessType::kElements) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    // NOTE(review): the template arguments of 'platform::is_same' below were stripped in
+    // this vendored copy (the C % 32 / C % 64 checks suggest interleaved layout tests);
+    // restore them from the upstream header.
+    if (platform::is_same>::value) {
+      if (problem_size.C % 32) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
+    if (platform::is_same>::value) {
+      if (problem_size.C % 64) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
+    return Status::kSuccess;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace conv
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h
new file mode 100644
index 0000000000000000000000000000000000000000..c608ce5305039ce42bd017fd74f14658a6c593da
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h
@@ -0,0 +1,360 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2.
Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile)
+    matrix from memory.
+
+    This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory.
+
+    The iterator is specialized for each of the three convolution operators: forward propagation (Fprop),
+    backward data gradient (Dgrad), and backward weight gradient (Wgrad).
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/array.h"
+#include "cutlass/coord.h"
+#include "cutlass/matrix_shape.h"
+#include "cutlass/predicate_vector.h"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/tensor_view.h"
+#include "cutlass/layout/pitch_linear.h"
+#include "cutlass/layout/tensor.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/conv2d_problem_size.h"
+#include "cutlass/conv/threadblock/conv2d_params.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace conv {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Tile access iterator over the fprop activation tensor, specialized for problems with
+// few channels: the contiguous GEMM dimension enumerates flattened (r, s, c) positions.
+// NOTE(review): the template-argument list of 'cutlass::AlignedArray' below appears to
+// have been stripped in this vendored copy; compare against the upstream CUTLASS header.
+template <
+  typename Shape_,
+  typename Element_,
+  typename Layout_,
+  typename ThreadMap_,
+  typename AccessType_ = cutlass::AlignedArray
+>
+class Conv2dFpropActivationTileAccessIteratorFewChannels {
+public:
+
+  //
+  // Types
+  //
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = Layout_;
+  using TensorCoord = typename Layout::TensorCoord;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+  using TensorRef = cutlass::TensorRef;
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+  static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels;
+  static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided;
+  static int const kConvDim = 2;
+  using ConvProblemSize = typename conv::Conv2dProblemSize;
+
+  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
+  static int const kPositionsPerTile = Shape::kColumn;
+
+  static int const kAccessesPerVector = kElementsPerAccess / AccessType::kElements;
+
+  static bool const kUseFastDivmodPrologue = true;
+  static bool const kUseFastDivmodMainloop = true;
+
+  // A value of 0 means "not specialized at compile time; read the value from
+  // problem_size_ instead" (see at() and can_implement())
+  static int const kStrideH = 0;
+  static int const kStrideW = 0;
+  static int const kDilationH = 0;
+  static int const kDilationW = 0;
+
+  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
+    "Vectors implied by the thread map must be divisible by the access type.");
+
+  //
+  // Simplifying assertions
+  //
+  static_assert(ThreadMap::Iterations::kContiguous == 1,
+    "Require Iterations::kContiguous == 1");
+
+  //
+  // Parameters structure
+  //
+
+  using Params = Conv2dFewChannelsParams;
+
+private:
+
+  Params const &params_;
+  Conv2dProblemSize const &problem_size_;
+  LongIndex iteration_contiguous_;
+  LongIndex iteration_strided_;
+  LongIndex iteration_vector_;
+  char const *pointer_;
+
+  // flattened (r, s, c) position in the filter/channel space
+  int rsc_index_;
+  // per-strided-access decomposition of the GEMM row index into (n, p, q)
+  int offset_n_[ThreadMap::Iterations::kStrided];
+  int offset_p_[ThreadMap::Iterations::kStrided];
+  int offset_q_[ThreadMap::Iterations::kStrided];
+
+public:
+
+  CUTLASS_HOST_DEVICE
+  Conv2dFpropActivationTileAccessIteratorFewChannels(
+    Params const &params,
+    Conv2dProblemSize const &problem_size,
+    Element const *ptr,
+    int thread_idx,
+    MatrixCoord const &threadblock_offset = MatrixCoord()       // tile index - units are threadblock-scoped tiles
+  ):
+    params_(params),
+    problem_size_(problem_size),
+    pointer_(reinterpret_cast(ptr)),
+    rsc_index_(0) {
+
+    layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx);
+
+    rsc_index_ = (threadblock_offset.column() + thread_coord.contiguous());
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+
+      if (kUseFastDivmodPrologue) {
+        int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq);
+        offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual);
+      }
+      else {
+        offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q);
+        int residual = offset_npq % (problem_size_.P * problem_size_.Q);
+
+        offset_p_[s] = residual / problem_size_.Q;
+        offset_q_[s] = residual % problem_size_.Q;
+      }
+    }
+
+    set_iteration_index(0);
+  }
+
+  CUTLASS_HOST_DEVICE
+  static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) {
+    return Params(problem_size, layout);
+  }
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(Index index) {
+    iteration_vector_ = index % kAccessesPerVector;
+    int residual_access = index / kAccessesPerVector;
+    iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
+    iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
+  }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    pointer_ += pointer_offset * sizeof_bits::value / 8;
+  }
+
+  // Advances the flattened (r, s, c) position by one threadblock tile
+  CUTLASS_HOST_DEVICE
+  void advance() {
+
+    rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices;
+  }
+
+  /// Returns the coordinate in the activations tensor X that is currently pointed to
+  /// by the iterator.
+  CUTLASS_HOST_DEVICE
+  TensorCoord at() const {
+    int n = offset_n_[iteration_strided_];
+    int p = offset_p_[iteration_strided_];
+    int q = offset_q_[iteration_strided_];
+
+    int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements;
+
+    // Decompose the flattened index into (r, s, c) with c fastest-varying
+    int r = 0;
+    int s = 0;
+    int c = 0;
+
+    if (kUseFastDivmodMainloop) {
+      int rs_index = params_.divmod_C.divmod(c, rsc_index);
+      r = params_.divmod_S.divmod(s, rs_index);
+    }
+    else {
+      c = (rsc_index % problem_size_.C);
+
+      int rs_index = (rsc_index / problem_size_.C);
+      s = (rs_index % problem_size_.S);
+      r = (rs_index / problem_size_.S);
+    }
+
+    // Cross-correlation vs. convolution: convolution flips the filter
+    if (problem_size_.mode == Mode::kConvolution) {
+      r = (problem_size_.R - 1 - r);
+      s = (problem_size_.S - 1 - s);
+    }
+
+    // kStrideH/W and kDilationH/W are 0 here, so the runtime problem-size values are used
+    int stride_h = kStrideH;
+    if (!kStrideH) {
+      stride_h = problem_size_.stride_h;
+    }
+
+    int stride_w = kStrideW;
+    if (!kStrideW) {
+      stride_w = problem_size_.stride_w;
+    }
+
+    int dilation_h = kDilationH;
+    if (!kDilationH) {
+      dilation_h = problem_size_.dilation_h;
+    }
+
+    int dilation_w = kDilationW;
+    if (!kDilationW) {
+      dilation_w = problem_size_.dilation_w;
+    }
+
+    int h = p * stride_h - problem_size_.pad_h + r * dilation_h;
+    int w = q * stride_w - problem_size_.pad_w + s * dilation_w;
+
+    return TensorCoord(n, h, w, c);
+  }
+
+  /// Returns true if the current coordinate is within the activations tensor X
+  CUTLASS_HOST_DEVICE
+  bool valid() const {
+
+    TensorCoord coord = at();
+
+    // h/w may be negative because of padding, hence the explicit >= 0 checks
+    bool in_bounds =
+      coord.n() < problem_size_.N &&
+      coord.h() >= 0 && coord.h() < problem_size_.H &&
+      coord.w() >= 0 && coord.w() < problem_size_.W &&
+      coord.c() < problem_size_.C;
+
+    return in_bounds;
+  }
+
+  /// Returns a pointer to the vector starting at the current coordinate
+  CUTLASS_HOST_DEVICE
+  AccessType const *get() const {
+
+    TensorCoord coord = at();
+
+    // NOTE(review): the offset accumulates in int32_t using precomputed NHWC strides;
+    // may overflow for very large activation tensors - verify against upstream.
+    int32_t offset =
+      coord.n() * params_.stride_n +
+      coord.h() * params_.stride_h +
+      coord.w() * params_.stride_w +
+      coord.c();
+
+    AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8);
+
+    return ptr;
+  }
+
+  /// Increments to the next memory access
+  CUTLASS_HOST_DEVICE
+  Conv2dFpropActivationTileAccessIteratorFewChannels &operator++() {
+    ++iteration_vector_;
+    if (iteration_vector_ < kAccessesPerVector) {
+      return *this;
+    }
+    iteration_vector_ = 0;
+
+    ++iteration_contiguous_;
+    if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
+      return *this;
+    }
+    iteration_contiguous_ = 0;
+
+    ++iteration_strided_;
+    if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
+      return *this;
+    }
+    iteration_strided_ = 0;
+
+    return *this;
+  }
+
+  /// Determines whether the Implicit GEMM can execute the given problem.
+  CUTLASS_HOST_DEVICE
+  static Status can_implement(Conv2dProblemSize const &problem_size) {
+
+    // check alignment constraint on iterator's contiguous dimension
+    if (problem_size.C % AccessType::kElements) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    // A nonzero compile-time stride/dilation must match the runtime problem
+    if (kDilationH && problem_size.dilation_h != kDilationH) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    if (kDilationW && problem_size.dilation_w != kDilationW) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    if (kStrideH && problem_size.stride_h != kStrideH) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    if (kStrideW && problem_size.stride_w != kStrideW) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    // NOTE(review): the template arguments of 'platform::is_same' below were stripped in
+    // this vendored copy (the C % 32 / C % 64 checks suggest interleaved layout tests);
+    // restore them from the upstream header.
+    if (platform::is_same>::value) {
+      if (problem_size.C % 32) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
+    if (platform::is_same>::value) {
+      if (problem_size.C % 64) {
+        return Status::kErrorInvalidProblem;
+      }
+    }
+
+    return Status::kSuccess;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace conv
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h
new file mode 100644
index 0000000000000000000000000000000000000000..ed0e38c285c78ba570506074f40f6bc5cff45a76
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h
@@ -0,0 +1,353 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropActivationTileAccessIteratorFixedChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kFilterPositionsPerTile = Shape::kColumn / AccessType::kElements; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const 
kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static int const kStrideH = 0; + static int const kStrideW = 0; + static int const kDilationH = 0; + static int const kDilationW = 0; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rs_index_; + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFixedChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rs_index_(0) { + + // + // This requires problem_size.C == AccessType::kElements + // + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rs_index_ = (threadblock_offset.column() + thread_coord.contiguous()) / AccessType::kElements; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + if (kUseFastDivmodPrologue) { + int residual = params_.divmod_Q.divmod(offset_q_[s], offset_npq); + offset_n_[s] = params_.divmod_P.divmod(offset_p_[s], residual); + } + else { + 
offset_n_[s] = offset_npq / (problem_size_.P * problem_size_.Q); + int residual = offset_npq % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int rs_index = rs_index_ + iteration_vector_; + + int r = 0; + int s = 0; + + if (kUseFastDivmodMainloop) { + r = params_.divmod_S.divmod(s, rs_index); + } + else { + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int stride_h = kStrideH; + if (!kStrideH) { + stride_h = problem_size_.stride_h; + } + + int stride_w = kStrideW; + if (!kStrideW) { + stride_w = problem_size_.stride_w; + } + + int dilation_h = kDilationH; + if (!kDilationH) { + dilation_h = problem_size_.dilation_h; + } + + int dilation_w = kDilationW; + if (!kDilationW) { + dilation_w = problem_size_.dilation_w; + } + + int h = p * stride_h - problem_size_.pad_h + r * dilation_h; + int w = q * stride_w - problem_size_.pad_w + s * dilation_w; + + return TensorCoord(n, h, w, 0); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + coord.c(); + + AccessType const *ptr = reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorFixedChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + 
return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C != AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (kDilationH && problem_size.dilation_h != kDilationH) { + return Status::kErrorInvalidProblem; + } + + if (kDilationW && problem_size.dilation_w != kDilationW) { + return Status::kErrorInvalidProblem; + } + + if (kStrideH && problem_size.stride_h != kStrideH) { + return Status::kErrorInvalidProblem; + } + + if (kStrideW && problem_size.stride_w != kStrideW) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h new file mode 100644 index 
0000000000000000000000000000000000000000..1a5c33e885be7521981e2d4bc5fc35f3b1412ebe --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorNCxHWx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + using Mask = uint64_t; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require 
Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFpropActivationIteratorOptimizedParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (r, s) + int filter_r_; + int filter_s_; + int filter_c_; + + Index masks_[ThreadMap::Iterations::kStrided][kAccessesPerVector][2]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + filter_c_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_npq / (problem_size_.P * problem_size_.Q); + // int residual = offset_npq % (problem_size_.P * problem_size_.Q); + // + // offset_p[s] = residual / problem_size_.Q; + // offset_q[s] = residual % problem_size_.Q; + // + + int residual; + + params.pq_divmod(offset_n[s], residual, offset_npq); + params.q_divmod(offset_p[s], offset_q[s], residual); + + TensorCoord 
coord = at_(offset_n[s], offset_p[s], offset_q[s], 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; + + bool pred = (offset_n[s_idx] < problem_size_.N && h >= 0 && h < problem_size_.H); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][0] |= (pred << r); + } + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; + + bool pred = (w >= 0 && w < problem_size_.W); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + masks_[s_idx][v_idx][1] |= (pred << s); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + /// Returns the coordinate in the 
activations tensor X that is correspoinding to + // output npq and filter position r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int p, int q, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, h, w, filter_c_); + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + filter_s_ = 0; + ++filter_r_; + + if (filter_r_ < problem_size_.R) { + next_idx = 1; + } + else { + filter_r_ = 0; + next_idx = 2; + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 2) { + filter_c_ += params_.filter_c_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void 
clear_mask(bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; + } + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][v][0] = clear ? 0 : masks_[s][v][0]; + masks_[s][v][1] = clear ? 0 : masks_[s][v][1]; + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][iteration_vector_][0] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][iteration_vector_][1] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropActivationTileAccessIteratorOptimized &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.C % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.C % 64) { + return Status::kErrorInvalidProblem; + } + } + + // Conv2dFpropActivationTileAccessIteratorOptimized has constraint on filter positions + // due to the number of mask bits. + if (problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..ed200ed3cf030055b3f7ba470748c91c3751fbfe --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, + bool IsDeconv_ = false +> +class Conv2dFpropFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / 
AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_r_; + int filter_s_; + int filter_c_; + int filter_c_init_; + int crs_cnt_; + int crs_per_group_; + int group_idx_offset_c_; + int channels_per_group_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_c_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? 
problem_size_.C : problem_size_.K); + + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + if (kGroupMode == conv::GroupMode::kDepthwise){ + channels_per_group_ = 1; + crs_per_group_ = problem_size_.S * problem_size_.R; + } else { + channels_per_group_ = input_channels / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (output_channels / problem_size_.groups); + } + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + crs_cnt_ = 0; + filter_c_ = filter_c_init_; + if (kGroupMode != 
conv::GroupMode::kDepthwise) { + // moves to next group + ++group_idx_offset_c_; + } + } else { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + } + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + int c = filter_c_ + iteration_vector_ * AccessType::kElements; + + return TensorCoord(k, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + if (kGroupMode == conv::GroupMode::kNone) { + return coord.n() < output_channels && coord.c() < input_channels; + } else if (kGroupMode == conv::GroupMode::kDepthwise) { + return coord.n() < output_channels && coord.c() < 1; // channels_per_group_ is always equal to ONE. 
+ } else { + return coord.n() < output_channels && coord.c() < channels_per_group_ && + group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; + } + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + + // check alignment constraint on iterator's contiguous dimension + if ((input_channels / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (output_channels % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (output_channels % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h new file mode 100644 index 0000000000000000000000000000000000000000..f208c9a5bb2ee697626a8caebc5715073ecdc7eb --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropFilterTileAccessIteratorFewChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFewChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kPositionsPerTile = Shape::kRow; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map 
must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rsc_index_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFewChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rsc_index_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rsc_index_ = (threadblock_offset.row() + thread_coord.contiguous()); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + rsc_index_ += kPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate 
in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int rsc_index = rsc_index_ + iteration_vector_ * AccessType::kElements; + + int c = 0; + int s = 0; + int r = 0; + + if (kUseFastDivmodMainloop) { + int rs_index = params_.divmod_C.divmod(c, rsc_index); + r = params_.divmod_S.divmod(s, rs_index); + } + else { + c = (rsc_index % problem_size_.C); + int rs_index = (rsc_index / problem_size_.C); + + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, r, s, c); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + bool in_bounds = + coord.n() < problem_size_.K && + coord.h() >= 0 && + coord.h() < problem_size_.R && + coord.c() < problem_size_.C; + + return in_bounds; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + + coord.c(); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFewChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h new file mode 100644 index 0000000000000000000000000000000000000000..2dc2151d8ba2759d55f6602024a5072b31789cf6 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h @@ -0,0 +1,275 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dFpropFilterTileAccessIteratorFixedChannels { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kFixedChannels; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kFilterPositionsPerTile = Shape::kRow / AccessType::kElements; + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static bool const kUseFastDivmodPrologue = true; + static bool const kUseFastDivmodMainloop = true; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); 
+ + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv2dFewChannelsParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int rs_index_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFixedChannels( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + rs_index_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + rs_index_ = (threadblock_offset.row() + thread_coord.contiguous()) / AccessType::kElements; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + rs_index_ += kFilterPositionsPerTile * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the filter 
tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int rs_index = rs_index_ + iteration_vector_; + + int r = 0; + int s = 0; + + if (kUseFastDivmodMainloop) { + r = params_.divmod_S.divmod(s, rs_index); + } + else { + s = (rs_index % problem_size_.S); + r = (rs_index / problem_size_.S); + } + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, r, s, 0); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.h() >= 0 && coord.h() < problem_size_.R; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + + int32_t offset = + coord.n() * params_.stride_n + + coord.h() * params_.stride_h + + coord.w() * params_.stride_w + coord.c(); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorFixedChannels &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C != AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (problem_size.K % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (problem_size.K % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..9b12fbe3390c61f9f39ed54ad27cf78e65d80dff --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray, + bool IsDeconv_ = false +> +class Conv2dFpropFilterTileAccessIteratorOptimized{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require 
Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params : Conv2dFpropFilterIteratorOptimizedParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv2dFpropFilterIteratorOptimizedParams const &base): + Conv2dFpropFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv2dProblemSize const &problem_size, + Layout const &layout + ): + Conv2dFpropFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { + + } + }; + +private: + + Conv2dFpropFilterIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_rs_; + int filter_c_; + int channels_per_group_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorOptimized( + Conv2dFpropFilterIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_rs_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + Index column = threadblock_offset.column() + 
thread_coord.strided(); + channels_per_group_ = (IsDeconv ? problem_size_.K : problem_size_.C) / problem_size_.groups; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + predicates_[v_idx] |= (pred << s); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); + } + + pointer_ += ( + params_.layout({filter_c_, column}) + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_rs; + + // moves to the next tile + ++filter_rs_; + if (filter_rs_ == params_.RS) { + + filter_rs_ = 0; + next = params_.inc_next_c; + filter_c_ += params_.filter_c_delta; + } + + CUTLASS_PRAGMA_UNROLL + for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); + } + + pointer_ += next; + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask(int v, bool clear = true) { + predicates_[v] = clear ? 
0u : predicates_[v]; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + return (predicates_[iteration_vector_] & (1u << iteration_strided_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_k; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + + // check alignment constraint on iterator's contiguous dimension + if ((input_channels / problem_size.groups) % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + if (platform::is_same>::value) { + if (output_channels % 32) { + return Status::kErrorInvalidProblem; + } + } + + if (platform::is_same>::value) { + if (output_channels % 64) { + return Status::kErrorInvalidProblem; + } + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h new file mode 100644 index 0000000000000000000000000000000000000000..8a3828fccb00b32d70e215785be3da1d317ed38a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_params.h @@ -0,0 +1,893 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv2d analytic tile iterators +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dAnalyticParams { + + using Layout = Layout_; + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv2dAnalyticParams( + Conv2dProblemSize const &, // unused; placeholder to match other Params interfaces. + Layout const &layout + ): layout(layout) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv2d analytic tile iterators +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFewChannelsParams { + + using Layout = Layout_; + + + int32_t stride_w; + int32_t stride_h; + int32_t stride_n; + + FastDivmod divmod_P; + FastDivmod divmod_Q; + FastDivmod divmod_S; + FastDivmod divmod_C; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFewChannelsParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFewChannelsParams( + Conv2dProblemSize const &problem_size, // unused; placeholder to match other Params interfaces. 
+ Layout const &layout + ): + stride_w(int32_t(layout.stride()[0])), + stride_h(int32_t(layout.stride()[1])), + stride_n(int32_t(layout.stride()[2])), + divmod_P(problem_size.P), + divmod_Q(problem_size.Q), + divmod_S(problem_size.S), + divmod_C(problem_size.C) + { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams +struct Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int tiled_rows_per_filter; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientTileAccessIteratorAnalyticParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape + ): layout(layout) { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED + +CUTLASS_HOST_DEVICE +void TraceIteratorParams( + char const *conv_operator, + char const *operand, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta +) { + +#if !defined(__CUDA_ARCH__) + + char const *fname = "conv_iterator_params.csv"; + + std::ifstream test(fname); + bool file_exists = test.is_open(); + + if (file_exists) { + test.close(); + } + + std::ofstream trace("conv_iterator_params.csv", std::ofstream::app); + + if (!file_exists) { + trace + << 
"Operator,Operand,ElementSize,CtaRows,CtaColumns,ThreadCount,AccessSize," + << "IterationsContiguous,IterationsStrided,DeltaContiguous,DeltaStrided\n"; + } + + trace << conv_operator << "," << operand << "," << element_size_bits << "," + << threadblock_shape.row() << "," << threadblock_shape.column() + << "," << thread_count << "," << access_size + << "," << threadmap_iterations.contiguous() << "," << threadmap_iterations.strided() + << "," << threadmap_delta.contiguous() << "," << threadmap_delta.strided() << "\n"; +#endif +} + +#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) \ + TraceIteratorParams(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta); + +#else + +#define TRACE_CONV_INITIALIZERS(conv_op, operand, element_size, cta_shape, thread_count, access_size, iterations, delta) {} + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFpropActivationIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template<> +struct Conv2dFpropActivationIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next C} + int filter_c_delta; // number of logical elements to add to filter_c_ + int PQ; // product of P*Q + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< 
size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + PQ(problem_size.P * problem_size.Q), + pq_divmod(PQ), + q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? -1 : 1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } + +#if ENABLE_CONV2D_PARAMS_PRINT + /// Prints internal state. 
+ CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropActivationIteratorOptimizedParams:\n" + " layout(w: %d, h: %d, n: %d)\n" + " inc_next[%ld, %ld, %ld]\n" + " filter_c_delta(%d) - PQ(%d)\n" + " pq_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n" + " q_divmod(divisor: %d, multiplier: %u, shift_right: %u)\n", + stride[0], stride[1], stride[2], + inc_next[0], inc_next[1], inc_next[2], + filter_c_delta, + PQ, + pq_divmod.divisor, + pq_divmod.multiplier, + pq_divmod.shift_right, + q_divmod.divisor, + q_divmod.multiplier, + q_divmod.shift_right + ); + } +#endif +}; + +/// Parameters structure used for Conv2dFpropActivationTileIteratorOptimized +template +struct Conv2dFpropActivationIteratorOptimizedParams> { + static int const kInterleaved = Interleaved_; + + using Layout = layout::TensorNCxHWx; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next C} + int filter_c_delta; // number of logical elements to add to filter_c_ + int PQ; // product of P*Q + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), PQ(problem_size.P * problem_size.Q), pq_divmod(PQ), q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); + + // next S + inc_next[0] = conv_sign * (kInterleaved * problem_size.dilation_w) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_h + - (problem_size.S - 1) * kInterleaved * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[1]) + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[0] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * kInterleaved * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename Layout_ = layout::TensorNHWC > +struct Conv2dFpropFilterIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Conv2dFpropFilterIteratorOptimizedParams +{ + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + 
TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + RS = problem_size.R * problem_size.S; + + inc_next_k = (int64_t(layout.stride()[2]) * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( int64_t(layout.stride()[0]) + - int64_t(layout.stride()[2]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices + - int64_t(RS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } + +#if ENABLE_CONV2D_PARAMS_PRINT + /// Prints internal state. + CUTLASS_HOST_DEVICE + void print() { + auto stride = layout.stride(); + printf( + "Conv2dFpropFilterIteratorOptimizedParams:\n" + " layout[%d, %d, %d]\n" + " RS(%d), filter_c_delta(%d), inc_next(k: %ld, rs: %ld, c: %ld)\n", + stride[0], stride[1], stride[2], + RS, + filter_c_delta, + inc_next_k, inc_next_rs, inc_next_c + ); + } +#endif +}; + +template +struct Conv2dFpropFilterIteratorOptimizedParams> +{ + static int const kInterleaved = Interleaved_; + using Layout = layout::TensorCxRSKx; + + Layout layout; + int RS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dFpropFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + 
layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + TRACE_CONV_INITIALIZERS("conv2d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + RS = problem_size.R * problem_size.S; + + inc_next_k = (kInterleaved * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( int64_t(layout.stride()[0]) + - kInterleaved * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices / kInterleaved * int64_t(layout.stride()[2]) + - int64_t(RS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * kInterleaved + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Dgrad Optimized Dy params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters object for Conv2d DGRAD OutputGradient (dy) iterator +struct Conv2dDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next K} + + int filter_k_delta; // number of logical elements to add to filter_k_ + + int HW; // product of H*W + + FastDivmod hw_divmod; + FastDivmod w_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradOutputGradientIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord 
threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + HW(problem_size.H *problem_size.W), + hw_divmod(HW), + w_divmod(problem_size.W) { + + TRACE_CONV_INITIALIZERS("conv2d_dgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); + + // next S + inc_next[0] = conv_sign * ( + (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + (int64_t)layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next K + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h + - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad Optimized Dy params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int64_t inc_next[3]; // {next S, next R, next K} + + int filter_k_delta; // number of logical elements to add to filter_k_ + + int tiled_rows_per_filter; + + int conv_sign; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradOutputGradientIteratorOptimizedParams( + 
Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape + ): layout(layout) { + + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, threadblock_shape.row()); + + tiled_rows_per_filter = tile_m_per_filter * threadblock_shape.row(); + + conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); + + // next S + inc_next[0] = conv_sign * ( + (int64_t)layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + (int64_t)layout.stride()[1] * problem_size.dilation_h + ) * element_size_bits / 8; + + // next K + inc_next[2] = ( + threadblock_shape.column() * problem_size.split_k_slices + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +// Dgrad Optimized w params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next_rs; // offset in units of bytes to next RS position + int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dDgradFilterIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each 
element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), RS(problem_size.R * problem_size.S) { + + TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_rs = + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] + ) * element_size_bits / 8; + + inc_next_k = + ( + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] + - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] + ) * element_size_bits / 8; + + filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////// +// StridedDgrad Optimized w params (layout::TensorNHWC) +///////////////////////////////////////////////////////////////////////////////////////////////// +struct Conv2dStridedDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + int RS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next[3]; // {next S, next R, next K} + int64_t reset_bytes; // offset in units of bytes to move back the pointer + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dStridedDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dStridedDgradFilterIteratorOptimizedParams( + Conv2dProblemSize 
const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), RS(problem_size.R * problem_size.S) { + + TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; + + // next S + inc_next[0] = + ( (int64_t)layout.stride()[0] * problem_size.stride_w + //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ) * element_size_bits / 8; + + // next R + inc_next[1] = + ( (int64_t)layout.stride()[1] * problem_size.stride_h + //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ) * element_size_bits / 8; + + // next K + inc_next[2] = + ( + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] + //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] + //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ) * element_size_bits / 8; + + // offset in units of bytes to move the pointer in backward direction + reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] + * element_size_bits / 8; + + filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters object for Conv2d WGRAD Output Gradient (dy) iterator +struct Conv2dWgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int NPQ; // precomputd product of N*P*Q for clearing predicates + + FastDivmod pq_divmod; + FastDivmod 
q_divmod; + + int64_t offset_next_strided; // offset in units of bytes to next npq coordinate within tile + int64_t offset_next_contiguous; // offset in units of bytes to next k coordinate within tile + int64_t inc_next_npq; // offset in units of bytes to next npq position in subsequent tile + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + NPQ(problem_size.N * problem_size.P * problem_size.Q), + pq_divmod(problem_size.P * problem_size.Q), + q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv2d_wgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + + offset_next_contiguous = (threadmap_delta.contiguous()) + * element_size_bits / 8; + + inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + } +}; + +struct Conv2dWgradActivationIteratorOptimizedParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod sc_divmod; + FastDivmod pq_divmod; + FastDivmod q_divmod; + FastDivmod c_divmod; + FastDivmod s_divmod; + int small_channel_conv_s_offset; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams( + Conv2dProblemSize const 
&problem_size, + Layout const &layout + ): + layout(layout), + sc_divmod(problem_size.S * problem_size.C), + pq_divmod(problem_size.P * problem_size.Q), + q_divmod(problem_size.Q), + c_divmod(problem_size.C), + s_divmod(problem_size.S * problem_size.dilation_w), + small_channel_conv_s_offset((problem_size.S - 1) * problem_size.dilation_w - problem_size.pad_w) { + } + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationIteratorOptimizedParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + Conv2dWgradActivationIteratorOptimizedParams( + problem_size, + layout + ) { + + TRACE_CONV_INITIALIZERS("conv2d_wgrad", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + } +}; + +struct PredicatedScaleBiasVectorAccessIteratorParams { + public: + /// Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams() { } + + // Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams( + Conv2dProblemSize const &problem_size, + layout::PitchLinear const &layout) {} + + // Default ctor + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIteratorParams( + Conv2dProblemSize const &problem_size, + layout::RowMajor const &layout) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..13bd29b7a0547eee204d642a3cd67a24709f89ea --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -0,0 +1,337 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template wraps the tile access iterator concept to load whole tiles from tensors in + memory used for implicit GEMM convolution. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TileIterator { +public: + using TileAccessIterator = TileAccessIterator_; + + using Shape = typename TileAccessIterator::Shape; + using Element = typename TileAccessIterator::Element; + using Layout = typename TileAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = typename TileAccessIterator::ThreadMap; + using AccessType = typename TileAccessIterator::AccessType; + using TensorRef = typename TileAccessIterator::TensorRef; + using Index = typename 
TileAccessIterator::Index; + using LongIndex = typename TileAccessIterator::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; + using Params = typename TileAccessIterator::Params; + static int const kConvDim = TileAccessIterator::kConvDim; + using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + +private: + + /// Internal state + TileAccessIterator tile_access_iterator_; + +public: + + /// Constructor + CUTLASS_HOST_DEVICE + TileIterator( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_(params, problem_size, ptr, thread_idx, threadblock_offset) { } + + CUTLASS_HOST_DEVICE + static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { + return TileAccessIterator::getParams(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + tile_access_iterator_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + tile_access_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIterator &operator++() { + tile_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + TileIterator operator++(int) { + TileIterator self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[idx], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); + + ++tile_access_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + tile_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + tile_access_iterator_.advance(); + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // dispatch to iterator implementation + return TileAccessIterator::can_implement(problem_size); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad Tile Iterator +template +class TileIteratorStridedDgrad { +public: + using TileAccessIterator = TileAccessIterator_; + + using Shape = typename TileAccessIterator::Shape; + using Element = typename TileAccessIterator::Element; + using Layout = typename TileAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = typename TileAccessIterator::ThreadMap; + using AccessType = typename TileAccessIterator::AccessType; + using TensorRef = typename TileAccessIterator::TensorRef; + using Index = typename TileAccessIterator::Index; + using LongIndex = typename TileAccessIterator::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = TileAccessIterator::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = TileAccessIterator::kStrideSupport; + using Params = typename TileAccessIterator::Params; + static int const kConvDim = TileAccessIterator::kConvDim; + using ConvProblemSize = typename TileAccessIterator::ConvProblemSize; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + +private: + + /// Internal state + TileAccessIterator tile_access_iterator_; + +public: + + /// Constructor (output gradient (Dy) OperandA ctor) + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + FastDivmod const &stride_h_divmod, FastDivmod const &stride_w_divmod, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_( + params, + problem_size, + ptr, + thread_idx, + 
stride_h_divmod, stride_w_divmod, + start_r, start_s, + threadblock_offset) { } + + /// Constructor (filter (w) OperandB ctor) + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + int start_r, int start_s, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + tile_access_iterator_(params, + problem_size, + ptr, + thread_idx, + start_r, start_s, + threadblock_offset) { } + + CUTLASS_HOST_DEVICE + static Params getParams(ConvProblemSize const &problem_size, Layout const &layout) { + return TileAccessIterator::getParams(problem_size, layout); + } + + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + tile_access_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad &operator++() { + tile_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + TileIteratorStridedDgrad operator++(int) { + TileIteratorStridedDgrad self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[c + s * ThreadMap::Iterations::kContiguous], + tile_access_iterator_.get() + pointer_offset, + tile_access_iterator_.valid() + ); + + ++tile_access_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + tile_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + tile_access_iterator_.advance(); + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // dispatch to iterator implementation + return TileAccessIterator::can_implement(problem_size); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..b5a240773b5912c9ace50916f55e8a1054092845 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params 
const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // Filter postion (r,s,c) in contiguous dimension stays constant for each gemm_iteration_k + int filter_r_[ThreadMap::Iterations::kContiguous]; + int filter_s_[ThreadMap::Iterations::kContiguous]; + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) + { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); + int residual = rsc_offset % (problem_size_.S * problem_size_.C); + + filter_s_[c] = residual / problem_size_.C; + filter_c_[c] = residual % problem_size_.C; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = 
residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int r, s, c; + + if (kAccessesPerVector == 1) { + /// One 128b aligned access fetching more than one element + c = filter_c_[iteration_contiguous_]; + r = filter_r_[iteration_contiguous_]; + s = filter_s_[iteration_contiguous_]; + } + else { + /// Multiple access to support non-128b alignment in contiguous dimension + c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; + int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; + s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; + int wrap_s = (filter_s_[iteration_contiguous_] + wrap_c) / problem_size_.S; + r = filter_r_[iteration_contiguous_] + wrap_s; + } + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); + int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, 
h, w, c); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..56197279a5a45be6ac992ac16a86011cd9843646 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -0,0 +1,321 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dWgradActivationIteratorOptimizedParams; + +private: + + 
Conv2dWgradActivationIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + // Precomputed effective filter postion (r,s) in contiguous dimension stays constant for each gemm_iteration_k + // required for npq -> nhw translation + int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorOptimized( + Conv2dWgradActivationIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) + { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int rsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // filter_r_[c] = rsc_offset / (problem_size_.S * problem_size_.C); + // int residual = rsc_offset % (problem_size_.S * problem_size_.C); + // + // filter_s_[c] = residual / problem_size_.C; + // filter_c_[c] = residual % problem_size_.C; + + int residual; + params_.sc_divmod(precomputed_filter_r_[c], residual, rsc_offset); + params_.c_divmod(precomputed_filter_s_[c], filter_c_[c], residual); + + int r = precomputed_filter_r_[c]; + int s = 
precomputed_filter_s_[c]; + + if (problem_size_.mode == Mode::kConvolution) { + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + precomputed_filter_r_[c] = -problem_size_.pad_h + r * problem_size_.dilation_h; + precomputed_filter_s_[c] = -problem_size_.pad_w + s * problem_size_.dilation_w; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_npq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_npq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + int r = precomputed_filter_r_[iteration_contiguous_]; + int s = precomputed_filter_s_[iteration_contiguous_]; + int c = filter_c_[iteration_contiguous_]; + + if (kAccessesPerVector > 1) { + // This code section is only to support non-128b alignment + // Multiple access to support non-128b alignment in contiguous dimension + int wrap_c; + params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); + + if (problem_size_.mode == Mode::kConvolution) { + s -= (problem_size_.dilation_w * wrap_c); + + int wrap_s; + params_.s_divmod(wrap_s, s, params_.small_channel_conv_s_offset - s); + s = params_.small_channel_conv_s_offset - s; + + r -= (problem_size_.dilation_h * wrap_s); + + } else { + s += (problem_size_.dilation_w * wrap_c); + + int wrap_s; + params_.s_divmod(wrap_s, s, s + problem_size_.pad_w); + s -= problem_size_.pad_w; + + r += (problem_size_.dilation_h * wrap_s); + } + } + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // int n = offset_npq_[iteration_strided_] / (problem_size_.P * problem_size_.Q); + // int residual = offset_npq_[iteration_strided_] % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, p, q; + + params_.pq_divmod(n, residual, offset_npq_[iteration_strided_]); + params_.q_divmod(p, q, residual); + + int h = p * problem_size_.stride_h + r; + int w = q * problem_size_.stride_w + s; + + return TensorCoord(n, h, w, c); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const 
*get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradActivationTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..ea48bc6de0f94015b24632b6609ee8c81dac93cb --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 
+1,260 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = 
ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dAnalyticParams; + +private: + + Params const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_[ThreadMap::Iterations::kContiguous]; + + int offset_npq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize filter_k for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] = threadblock_offset.column() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / 
kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_npq_[s] += Shape::kColumn * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int npq = offset_npq_[iteration_strided_]; + + int n = npq / (problem_size_.P * problem_size_.Q); + int residual = npq % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int k = filter_k_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements; + + return TensorCoord(n, p, q, k); + } + + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.h() < problem_size_.P && + coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + 
iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..8e5048fd304f8edd96cca25bc8735725d6f2e843 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + typename AccessType_ = cutlass::AlignedArray +> +class Conv2dWgradOutputGradientTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNHWC; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements 
of size 8b or greater."); + + // + // Parameters structure + // + + using Params = Conv2dWgradOutputGradientIteratorOptimizedParams; + +private: + + Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms_; + Conv2dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + uint32_t predicates_[kAccessesPerVector]; + int filter_k_; + int offset_npq_; + +public: + + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorOptimized( + Conv2dWgradOutputGradientIteratorOptimizedParams const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_k_(0), + offset_npq_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); + offset_npq_ = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; + int offset_npq = offset_npq_ + s * ThreadMap::Delta::kStrided; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + bool predicate = valid_(at_(offset_npq, filter_k + v * AccessType::kElements)); + + uint32_t pred = (predicate ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_[v] |= (pred << pred_idx); + } + } + } + + // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) + pointer_ += ( + offset_npq_ * params.layout.stride()[0] + filter_k_ + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + offset_npq_ += Shape::kColumn * problem_size_.split_k_slices; + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (offset_npq_ + s * ThreadMap::Delta::kStrided >= params_.NPQ) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + predicates_[v] = (predicates_[v] & (~kClearMask)); + } + } + } + + pointer_ += params_.inc_next_npq; + } + +private: + /// Returns 
the coordinate in the output gradient tensor Dy that is pointed to + /// by offset_npq and k. + CUTLASS_HOST_DEVICE + TensorCoord at_(int offset_npq, int k) const { + + // The subsequent fast_divmod() operations are equivalent to the following logical computation: + // + // + // int npq = offset_npq; + // int n = npq / (problem_size_.P * problem_size_.Q); + // int residual = npq % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, p, q; + + params_.pq_divmod(n, residual, offset_npq); + params_.q_divmod(p, q, residual); + + return TensorCoord(n, p, q, k); + } + + /// Returns true if the coord is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid_(TensorCoord coord) const { + + return coord.n() < problem_size_.N && + coord.c() < problem_size_.K; + } + +public: + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_[iteration_vector_] & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast( + pointer_ + + iteration_strided_ * params_.offset_next_strided + + iteration_contiguous_ * params_.offset_next_contiguous + ) + iteration_vector_; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv2dWgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return 
*this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..d996003f42587ef6de2af268e1903808991b34d9 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -0,0 +1,268 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dDgradFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or larger."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + 
char const *pointer_; + + // For a fixed filter position (t,r,s) find and fill offset_k_, offset_c_ in strided and contiguous dimension + int filter_t_; + int filter_r_; + int filter_s_; + int offset_k_[ThreadMap::Iterations::kStrided]; + int offset_c_[ThreadMap::Iterations::kContiguous]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + offset_c_[c] = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = + threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) 
{ + offset_k_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the filter tensor w that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = offset_c_[iteration_contiguous_]; + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, filter_t_, filter_r_, filter_s_, c); + } + + /// Returns true if the current coordinate is within the filter tensor w + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..a269b18b0010329dedd31c4689bff8db4fb46d2a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -0,0 +1,289 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity +> +class Conv3dDgradFilterTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = StrideSupport_; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Parameters structure + // + + struct Params : Conv3dDgradFilterIteratorOptimizedParams { + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv3dDgradFilterIteratorOptimizedParams const &base): + Conv3dDgradFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout 
+ ): + Conv3dDgradFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { } + + }; + +private: + + Conv3dDgradFilterIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_trs_; + int filter_k_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided * + ThreadMap::Iterations::kContiguous < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorOptimized( + Conv3dDgradFilterIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_(0), + filter_trs_(0), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.strided(); + Index column = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + s * ThreadMap::Delta::kStrided; + int filter_c = column + c * ThreadMap::Delta::kContiguous; + + uint32_t pred = ((filter_k < problem_size_.K && filter_c < problem_size_.C) ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_ |= (pred << pred_idx); + } + } + + pointer_ += ( + filter_k_ * params.layout.stride()[3] + column + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_trs; + + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == params_.TRS) { + + filter_trs_ = 0; + next = params_.inc_next_k; + filter_k_ += params_.filter_k_delta; + } + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (filter_k_ + s * ThreadMap::Delta::kStrided >= problem_size_.K) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + + predicates_ = (predicates_ & (~kClearMask)); + } + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_ & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradFilterTileAccessIteratorOptimized &operator++() { + 
++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_strided; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..700c3d12ddfd53b0acef8b6c11188499ca021f76 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,343 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kStrided +> +class Conv3dDgradOutputGradientTileAccessIteratorAnalytic; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv3dDgradOutputGradientTileAccessIteratorAnalytic strided dgrad needs special handling using +// unscaled coordinations +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < + Shape_, + Element_, + ThreadMap_, + conv::StrideSupport::kStrided +> { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const 
kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "DGRAD requires elements of size 8b or greater."); + + // + // Simpligying assertions + // + + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + ConvProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_k_; + int filter_t_; + int filter_r_; + int filter_s_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_d_[ThreadMap::Iterations::kStrided]; + int offset_w_[ThreadMap::Iterations::kStrided]; + int offset_h_[ThreadMap::Iterations::kStrided]; + +private: + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator but DOES NOT scale by the convolution stride. This is needed + /// to compute predicates in the valid() method. The return value of the public at() + /// method is correctly scaled. 
+ CUTLASS_HOST_DEVICE + TensorCoord unscaled_at_() const { + int n = offset_n_[iteration_strided_]; + int d = offset_d_[iteration_strided_]; + int h = offset_h_[iteration_strided_]; + int w = offset_w_[iteration_strided_]; + + int t = filter_t_; + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int z = (d + problem_size_.pad_d - t * problem_size_.dilation_d); + int p = (h + problem_size_.pad_h - r * problem_size_.dilation_h); + int q = (w + problem_size_.pad_w - s * problem_size_.dilation_w); + + return TensorCoord(n, z, p, q, filter_k_); + } + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // threadblock offset - units are whole CTA tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); + int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); + + offset_d_[s] = residual / (problem_size_.H * problem_size_.W); + residual = residual % (problem_size_.H * problem_size_.W); + + offset_h_[s] = residual / problem_size_.W; + offset_w_[s] = residual % problem_size_.W; + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout 
const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // move to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_k_ += Shape_::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the output tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + TensorCoord coord = unscaled_at_(); + + return TensorCoord( + coord.n(), + coord.d() / problem_size_.stride_d, + coord.h() / problem_size_.stride_h, + coord.w() / problem_size_.stride_w, + coord.c()); + } + + + /// Returns true if the current coordinate is within the output tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord unscaled_coord = unscaled_at_(); + TensorCoord coord = at(); + + return + !(unscaled_coord.d() % problem_size_.stride_d) && + !(unscaled_coord.h() % problem_size_.stride_h) && + !(unscaled_coord.w() % problem_size_.stride_w) && + coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.Z && + coord.h() >= 0 && coord.h() < problem_size_.P && + coord.w() >= 0 && coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = 
at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..69915babcbfcacc1a1830a4f9d70885aca5d40c8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,489 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + conv::StrideSupport StrideSupport_ = conv::StrideSupport::kUnity +> +class Conv3dDgradOutputGradientTileAccessIteratorOptimized { +public: + + static_assert(StrideSupport_ == conv::StrideSupport::kUnity, + "Only unit-stride dgrad is supported at this time."); + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kUnity; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + using Coord3D = Coord<3>; + static int const kAccessesPerVector = 1; + using Mask = 
uint64_t; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dDgradOutputGradientIteratorOptimizedParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (t, r, s) + int filter_t_; + int filter_r_; + int filter_s_; + int filter_k_; + + Index masks_[ThreadMap::Iterations::kStrided][3]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorOptimized( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + filter_k_(0), + filter_t_(0), + filter_r_(0), + filter_s_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_d[ThreadMap::Iterations::kStrided]; + int offset_h[ThreadMap::Iterations::kStrided]; + int offset_w[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_ndhw = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_ndhw / (problem_size_.D * problem_size_.H * problem_size_.W); + // int residual = offset_ndhw % (problem_size_.D * problem_size_.H * problem_size_.W); + // + // + // offset_d[s] = residual / (problem_size_.H 
* problem_size_.W); + // residual = residual % (problem_size_.H * problem_size_.W); + // + // offset_h[s] = residual / problem_size_.W; + // offset_w[s] = residual % problem_size_.W; + // + + int residual; + + // input: (ndhw offset) output: (n offset and resudial (dhw offset)) + params_.dhw_divmod(offset_n[s], residual, offset_ndhw); + // input: (dhw offset) output: (d offset and resudial (hw)) + params_.hw_divmod(offset_d[s], residual, residual); + // input: (hw offset) output: (h offset and resudial (w offset)) + params_.w_divmod(offset_h[s], offset_w[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_d[s], offset_h[s], offset_w[s], 0, 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int t = 0; t < problem_size_.T; ++t) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int t_ = t; + if (problem_size_.mode == Mode::kConvolution) { + t_ = problem_size_.T - 1 - t; + } + + int z = offset_d[s_idx] + problem_size_.pad_d - t_ * problem_size_.dilation_d; + + bool pred = (offset_n[s_idx] < problem_size_.N && z >= 0 && z < problem_size_.Z); + masks_[s_idx][0] |= (pred << t); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int p = offset_h[s_idx] + problem_size_.pad_h - r_ * problem_size_.dilation_h; + + bool pred = (p >= 0 && p < problem_size_.P); + masks_[s_idx][1] |= (pred << r); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { + s_ = problem_size_.S - 1 - s; + } + + int q = offset_w[s_idx] + 
problem_size_.pad_w - s_ * problem_size_.dilation_w; + + bool pred = (q >= 0 && q < problem_size_.Q); + masks_[s_idx][2] |= (pred << s); + } + } + + if (filter_k_ >= problem_size.K) { + clear_mask(); + } + + set_iteration_index(0); + + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + + /// Returns the coordinate in the output gradient tensor dy that is correspoinding to + // activation ndhw and filter position k, t, r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int d, int h, int w, int t, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + t = problem_size_.T - 1 - t; + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int z = d + problem_size_.pad_d - t * problem_size_.dilation_d; + int p = h + problem_size_.pad_h - r * problem_size_.dilation_h; + int q = w + problem_size_.pad_w - s * problem_size_.dilation_w; + + return TensorCoord(n, z, p, q, filter_k_); + } + + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are generated. Instead, our + // intent is to predicate the mov instructions. 
+ #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][0]) + : + "r"((int)clear), + "r"(masks_[s][0]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][1]) + : + "r"((int)clear), + "r"(masks_[s][1]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][2]) + : + "r"((int)clear), + "r"(masks_[s][2]) + ); + #else + if (clear) { + masks_[s][0] = 0; + masks_[s][1] = 0; + masks_[s][2] = 0; + } + #endif + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + + filter_s_ = 0; + ++filter_r_; + next_idx = 1; + + if (filter_r_ == problem_size_.R) { + filter_r_ = 0; + ++filter_t_; + + if (filter_t_ < problem_size_.T) { + next_idx = 2; + } + else { + filter_t_ = 0; + next_idx = 3; + } + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 3) { + filter_k_ += params_.filter_k_delta; + } + + clear_mask_(filter_k_ >= problem_size_.K); + } + + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask() { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < 
ThreadMap::Iterations::kStrided; ++s) { + masks_[s][0] = Mask(0); + masks_[s][1] = Mask(0); + masks_[s][2] = Mask(0); + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && + (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientTileAccessIteratorOptimized &operator++() { + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // This is specialized for unit stride + if (problem_size.stride() != Coord3D({1, 1, 1})) { + return Status::kErrorNotSupported; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorNotSupported; + } + + // Limit on filter size + if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..5a888e0fe4e63a255cd8bdb6b27de831691f71c8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dFpropActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dAnalyticParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_t_; + int 
filter_r_; + int filter_s_; + int filter_c_; + + int offset_n_[ThreadMap::Iterations::kStrided]; + int offset_z_[ThreadMap::Iterations::kStrided]; + int offset_p_[ThreadMap::Iterations::kStrided]; + int offset_q_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + offset_n_[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + offset_z_[s] = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + offset_p_[s] = residual / problem_size_.Q; + offset_q_[s] = residual % problem_size_.Q; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex 
pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int n = offset_n_[iteration_strided_]; + int z = offset_z_[iteration_strided_]; + int p = offset_p_[iteration_strided_]; + int q = offset_q_[iteration_strided_]; + + int t = filter_t_; + int r = filter_r_; + int s = filter_s_; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - filter_t_); + r = (problem_size_.R - 1 - filter_r_); + s = (problem_size_.S - 1 - filter_s_); + } + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = 
reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..057023c09cb73199bda94f62c3c879269d7b5189 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -0,0 +1,478 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. 
+ + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_ +> +class Conv3dFpropActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + using Mask = uint64_t; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = 
Conv3dFpropActivationIteratorOptimizedParams; + +private: + + Conv3dFpropActivationIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + + // One pointer per access + char const *pointer_[ThreadMap::Iterations::kStrided]; + + // current filter position (t, r, s) + int filter_t_; + int filter_r_; + int filter_s_; + int filter_c_; + + // mask for t, r, and s + Index masks_[ThreadMap::Iterations::kStrided][3]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorOptimized( + Conv3dFpropActivationIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() // tile index - units are threadblock-scoped tiles + ) : + params_(params), + problem_size_(problem_size), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + + int offset_n[ThreadMap::Iterations::kStrided]; + int offset_z[ThreadMap::Iterations::kStrided]; + int offset_p[ThreadMap::Iterations::kStrided]; + int offset_q[ThreadMap::Iterations::kStrided]; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + pointer_[s] = reinterpret_cast(ptr); + + int offset_nzpq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // offset_n[s] = offset_nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + // int residual = offset_nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // offset_z[s] = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // offset_p[s] = 
residual / problem_size_.Q; + // offset_q[s] = residual % problem_size_.Q; + // + + int residual; + + // input: (nzpq offset) output: (n offset and resudial (zpq offset)) + params.zpq_divmod(offset_n[s], residual, offset_nzpq); + // input: (zpq offset) output: (z offset and resudial (pq)) + params.pq_divmod(offset_z[s], residual, residual); + // input: (pq offset) output: (p offset and resudial (q offset)) + params.q_divmod(offset_p[s], offset_q[s], residual); + + TensorCoord coord = at_(offset_n[s], offset_z[s], offset_p[s], offset_q[s], 0, 0, 0); + + pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; + } + + clear_mask(); + + // mask predicates for filter position T + CUTLASS_PRAGMA_NO_UNROLL + for (int t = 0; t < problem_size_.T; ++t) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int t_ = t; + if (problem_size_.mode == Mode::kConvolution) { + t_ = problem_size_.T - 1 - t; + } + + int d = offset_z[s_idx] * problem_size_.stride_d - problem_size_.pad_d + t_ * problem_size_.dilation_d; + + bool pred = (offset_n[s_idx] < problem_size_.N && d >= 0 && d < problem_size_.D); + masks_[s_idx][0] |= (pred << t); + } + } + + // mask predicates for filter position R + CUTLASS_PRAGMA_NO_UNROLL + for (int r = 0; r < problem_size_.R; ++r) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int r_ = r; + if (problem_size_.mode == Mode::kConvolution) { + r_ = problem_size_.R - 1 - r; + } + + int h = offset_p[s_idx] * problem_size_.stride_h - problem_size_.pad_h + r_ * problem_size_.dilation_h; + + bool pred = (h >= 0 && h < problem_size_.H); + masks_[s_idx][1] |= (pred << r); + } + } + + // mask predicates for filter position S + CUTLASS_PRAGMA_NO_UNROLL + for (int s = 0; s < problem_size_.S; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int s_idx = 0; s_idx < ThreadMap::Iterations::kStrided; ++s_idx) { + + int s_ = s; + if (problem_size_.mode == Mode::kConvolution) { 
+ s_ = problem_size_.S - 1 - s; + } + + int w = offset_q[s_idx] * problem_size_.stride_w - problem_size_.pad_w + s_ * problem_size_.dilation_w; + + bool pred = (w >= 0 && w < problem_size_.W); + masks_[s_idx][2] |= (pred << s); + } + } + + if (filter_c_ >= problem_size.C) { + clear_mask(); + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}); + } + +private: + + /// Returns the coordinate in the activations tensor X that is correspoinding to + // output nzpq and filter position t, r, s + CUTLASS_HOST_DEVICE + TensorCoord at_(int n, int z, int p, int q, int t, int r, int s) const { + + if (problem_size_.mode == Mode::kConvolution) { + t = problem_size_.T - 1 - t; + r = problem_size_.R - 1 - r; + s = problem_size_.S - 1 - s; + } + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_); + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_byte_offset_(LongIndex byte_offset) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + pointer_[s] += byte_offset; + } + } + + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void clear_mask_(bool clear) { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + // We are using inline PTX assembly here to avoid an CUDA C++ compilation + // artifact in which control flow instructions are 
generated. Instead, our + // intent is to predicate the mov instructions. + #if defined(__CUDA_ARCH__) + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][0]) + : + "r"((int)clear), + "r"(masks_[s][0]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][1]) + : + "r"((int)clear), + "r"(masks_[s][1]) + ); + asm volatile( + "{\n" + " .reg .pred p;\n" + " .reg .u32 m;" + " mov.u32 m, %2;" + " setp.ne.b32 p, %1, 0;\n" + " @p mov.u32 m, 0;\n" + " mov.u32 %0, m;\n" + "}\n" + : + "=r"(masks_[s][2]) + : + "r"((int)clear), + "r"(masks_[s][2]) + ); + #else + if (clear) { + masks_[s][0] = 0; + masks_[s][1] = 0; + masks_[s][2] = 0; + } + #endif + } + } + +public: + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + add_byte_offset_(pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_HOST_DEVICE + void advance() { + + int next_idx = 0; + + // moves to the next tile + ++filter_s_; + if (filter_s_ == problem_size_.S) { + + filter_s_ = 0; + ++filter_r_; + next_idx = 1; + + if (filter_r_ == problem_size_.R) { + filter_r_ = 0; + ++filter_t_; + + if (filter_t_ < problem_size_.T) { + next_idx = 2; + } + else { + filter_t_ = 0; + next_idx = 3; + } + } + } + + add_byte_offset_(params_.inc_next[next_idx]); + + if (next_idx == 3) { + filter_c_ += params_.filter_c_delta; + } + + clear_mask_(filter_c_ >= problem_size_.C); + } + + /// Clears the predicates + CUTLASS_HOST_DEVICE + void 
clear_mask() { + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + masks_[s][0] = Mask(0); + masks_[s][1] = Mask(0); + masks_[s][2] = Mask(0); + } + } + + CUTLASS_HOST_DEVICE + bool valid() { + + return + (masks_[iteration_strided_][0] & (Index(1) << filter_t_)) && + (masks_[iteration_strided_][1] & (Index(1) << filter_r_)) && + (masks_[iteration_strided_][2] & (Index(1) << filter_s_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast(pointer_[iteration_strided_]); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropActivationTileAccessIteratorOptimized &operator++() { + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // Conv3dFpropActivationTileAccessIteratorOptimized has constraint on filter positions + // due to the number of mask bits. 
+ if (problem_size.T > 32 || problem_size.R > 32 || problem_size.S > 32) { + return Status::kErrorNotSupported; + } + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..4a40d37e56bfc73744f80dc0cf84e30918b86a1b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -0,0 +1,259 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_, + bool IsDeconv_ = false +> +class Conv3dFpropFilterTileAccessIteratorAnalytic { +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + using Params = Conv3dAnalyticParams; + +private: + + Params const ¶ms_; + ConvProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + 
int filter_t_; + int filter_r_; + int filter_s_; + int filter_c_; + + int offset_k_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorAnalytic( + Params const ¶ms, + ConvProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_t_(0), + filter_r_(0), + filter_s_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_s_; + if (filter_s_ < problem_size_.S) { + return; + } + filter_s_ = 0; + + ++filter_r_; + if (filter_r_ < problem_size_.R) { + return; + } + filter_r_ = 0; + + ++filter_t_; + if (filter_t_ < problem_size_.T) { + return; + } + filter_t_ = 0; + + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = offset_k_[iteration_strided_]; + + return TensorCoord(k, filter_t_, filter_r_, filter_s_, filter_c_); + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + return coord.n() < output_channels && + coord.c() < input_channels; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(ConvProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension + if (input_channels % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..b4e7db3a4398b67a2a2cf185cf9e689a22d3d0b8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -0,0 +1,279 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC or TensorCxRSKx layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename Layout_, + typename ThreadMap_, + bool IsDeconv_ = false +> +class Conv3dFpropFilterTileAccessIteratorOptimized{ +public: + + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + + struct Params : Conv3dFpropFilterIteratorOptimizedParams { + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Conv3dFpropFilterIteratorOptimizedParams const &base): + 
Conv3dFpropFilterIteratorOptimizedParams(base) { } + + CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): + Conv3dFpropFilterIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ) { + + } + }; + +private: + + Conv3dFpropFilterIteratorOptimizedParams const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_trs_; + int filter_c_; + + // + // Assertions + // + + // We map predicates into bits packed in this uint32_t container + static_assert(ThreadMap::Iterations::kStrided < sizeof(predicates_) * 8, + "Currently, the number of loads per iteration is limited by the size of the predicates container."); + +public: + + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorOptimized( + Conv3dFpropFilterIteratorOptimizedParams const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_{0}, + filter_trs_(0), + filter_c_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + Index column = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); + predicates_ |= (pred << s); + } + + if (filter_c_ >= (IsDeconv ? 
problem_size_.K : problem_size_.C)) { + predicates_ = 0u; + } + + pointer_ += ( + params_.layout({filter_c_, column}) + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + LongIndex next = params_.inc_next_trs; + + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == params_.TRS) { + + filter_trs_ = 0; + next = params_.inc_next_c; + filter_c_ += params_.filter_c_delta; + } + + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { + predicates_ = 0; + } + + pointer_ += next; + } + + /// Returns true if the current coordinate is within the filter tensor W + CUTLASS_HOST_DEVICE + bool valid() { + return (predicates_ & (1u << iteration_strided_)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + return reinterpret_cast(pointer_); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dFpropFilterTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + + // Move to the next K coordinate within the tile + pointer_ += params_.inc_next_k; + + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + + // check alignment constraint on iterator's contiguous dimension + if (input_channels % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h new file mode 100644 index 0000000000000000000000000000000000000000..941f4e1dff7ebfffd6830ad76876087b98a0c8b0 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_params.h @@ -0,0 +1,508 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/conv3d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Params structure used for all Conv3d analytic tile iterators +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dAnalyticParams { + + using Layout = Layout_; + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dAnalyticParams() { } + + CUTLASS_HOST_DEVICE + Conv3dAnalyticParams( + Conv3dProblemSize const &, // unused; placeholder to match other Params interfaces. 
+ Layout const &layout + ): layout(layout) { + + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dFpropActivationIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for Conv3dFpropActivationTileIteratorOptimized +template<> +struct Conv3dFpropActivationIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int64_t inc_next[4]; // {next S, next R, next T, next C} + int filter_c_delta; // number of logical elements to add to filter_c_ + int ZPQ; // product of Z*P*Q + int PQ; // product of P*Q + + FastDivmod zpq_divmod; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dFpropActivationIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + PQ(problem_size.P * problem_size.Q), + ZPQ(problem_size.Z * problem_size.P * problem_size.Q), + zpq_divmod(ZPQ), + pq_divmod(PQ), + q_divmod(problem_size.Q) { + + TRACE_CONV_INITIALIZERS("conv3d_fprop", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 
-1 : 1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next T + inc_next[2] = conv_sign * ( + int64_t(layout.stride()[2]) * problem_size.dilation_d + - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next C + inc_next[3] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_c_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template< typename Layout_ = layout::TensorNDHWC > +struct Conv3dFpropFilterIteratorOptimizedParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Conv3dFpropFilterIteratorOptimizedParams +{ + + using Layout = layout::TensorNDHWC; + + Layout layout; + int TRS; + int filter_c_delta; + + int64_t inc_next_k; // offset in units of bytes to next K position + int64_t inc_next_trs; // offset in units of bytes to next TRS position + int64_t inc_next_c; // offset in units of bytes to next C position + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dFpropFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + 
Conv3dFpropFilterIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_fprop", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + TRS = problem_size.T * problem_size.R * problem_size.S; + + inc_next_k = (int64_t(layout.stride()[3]) * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_trs = + ( int64_t(layout.stride()[0]) + - int64_t(layout.stride()[3]) * (threadmap_iterations.strided() - 1) * threadmap_delta.strided() + ) * element_size_bits / 8; + + inc_next_c = + ( + threadblock_shape.row() * problem_size.split_k_slices + - int64_t(TRS - 1) * layout.stride()[0] + - int64_t(threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + ) * element_size_bits / 8; + + filter_c_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters object for Conv3d DGRAD OutputGradient (dy) iterator +struct Conv3dDgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int64_t inc_next[4]; // {next S, next R, next T, next K} + int filter_k_delta; // number of logical elements to add to filter_k_ + + FastDivmod dhw_divmod; + FastDivmod hw_divmod; + FastDivmod w_divmod; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dDgradOutputGradientIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, ///< layout object + int element_size_bits, ///< size of each element in 
bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), + dhw_divmod(problem_size.D * problem_size.H * problem_size.W), + hw_divmod(problem_size.H * problem_size.W), + w_divmod(problem_size.W) { + + TRACE_CONV_INITIALIZERS("conv3d_dgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + int conv_sign = (problem_size.mode == Mode::kConvolution ? 1 : -1); + + // next S + inc_next[0] = conv_sign * ( + int64_t(layout.stride()[0]) * problem_size.dilation_w + ) * element_size_bits / 8; + + // next R + inc_next[1] = conv_sign * ( + int64_t(layout.stride()[1]) * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next T + inc_next[2] = conv_sign * ( + int64_t(layout.stride()[2]) * problem_size.dilation_d + - (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // next K + inc_next[3] = ( + threadblock_shape.column() * problem_size.split_k_slices + - conv_sign * int64_t(problem_size.T - 1) * layout.stride()[2] * problem_size.dilation_d + - conv_sign * int64_t(problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h + - conv_sign * int64_t(problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + ) * element_size_bits / 8; + + // logical offset added to internal channel counter - units are elements, not bytes + filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters object for Conv2d DGRAD Filter (w) iterator +struct Conv3dDgradFilterIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + 
Layout layout; + int TRS; + int filter_k_delta; + + int64_t inc_next_strided; // offset in units of bytes to next K coordinate within tile + int64_t inc_next_trs; // offset in units of bytes to next TRS position + int64_t inc_next_k; // offset in units of bytes to next K position in subsequent tile + + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dDgradFilterIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dDgradFilterIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, ///< size of each element in bits + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): + layout(layout), TRS(problem_size.T * problem_size.R * problem_size.S) { + + TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; + + inc_next_trs = + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] + ) * element_size_bits / 8; + + inc_next_k = + ( + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] + - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] + ) * element_size_bits / 8; + + filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; + } +}; + +/// Parameters object for Conv3d WGRAD OutputGradient iterator +struct Conv3dWgradOutputGradientIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + using LongIndex = typename Layout::LongIndex; + + Layout layout; + + int NZPQ; // precomputd product of N*Z*P*Q for clearing predicates + int ZPQ; 
// product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. + + LongIndex offset_next_strided; // offset in units of bytes to next nzpq coordinate within tile + LongIndex offset_next_contiguous; // offset in units of bytes to next k coordinate within tile + LongIndex inc_next_nzpq; // offset in units of bytes to next nzpq position in subsequent tile + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "output_gradient", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + + offset_next_contiguous = (threadmap_delta.contiguous()) + * element_size_bits / 8; + + inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) + * element_size_bits / 8; + + // Precompute several quantities for fast modulo arithmetic. 
+ NZPQ = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q; + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + +/// Parameters object for Conv3d WGRAD Activation Tile Access Iterator +struct Conv3dWgradActivationIteratorOptimizedParams { + + using Layout = layout::TensorNDHWC; + + Layout layout; + + int RSC; // product of R*S*C + unsigned rsc_mul; // precomputed quantities for fast computation of div/% by RSC + unsigned rsc_shr; // in device code. + + int SC; // product of S*C + unsigned sc_mul; // precomputed quantities for fast computation of div/% by SC + unsigned sc_shr; // in device code. + + unsigned c_mul; // precomputed quantities for fast computation of div/% by C + unsigned c_shr; // in device code. + + int ZPQ; // product of Z*P*Q + unsigned zpq_mul; // precomputed quantities for fast computation of div/% by ZPQ + unsigned zpq_shr; // in device code. + + int PQ; // product of P*Q + unsigned pq_mul; // precomputed quantities for fast computation of div/% by PQ + unsigned pq_shr; // in device code. + + unsigned q_mul; // precomputed quantities for fast computation of div/% by Q + unsigned q_shr; // in device code. 
+ + // + // Methods + // + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams() { } + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationIteratorOptimizedParams( + Conv3dProblemSize const &problem_size, + Layout const &layout, + int element_size_bits, + MatrixCoord threadblock_shape, + int thread_count, + int access_size, + layout::PitchLinearCoord threadmap_iterations, + layout::PitchLinearCoord threadmap_delta + ): layout(layout) { + + TRACE_CONV_INITIALIZERS("conv3d_wgrad", "activation", + element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); + + // Precompute several quantities for fast modulo arithmetic. + RSC = problem_size.R * problem_size.S * problem_size.C; + find_divisor(rsc_mul, rsc_shr, RSC); + + SC = problem_size.S * problem_size.C; + find_divisor(sc_mul, sc_shr, SC); + + find_divisor(c_mul, c_shr, problem_size.C); + + ZPQ = problem_size.Z * problem_size.P * problem_size.Q; + find_divisor(zpq_mul, zpq_shr, ZPQ); + + PQ = problem_size.P * problem_size.Q; + find_divisor(pq_mul, pq_shr, PQ); + + find_divisor(q_mul, q_shr, problem_size.Q); + + } +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..97cad0a131667235fbab4c7dd092c1571ae3ee6c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -0,0 +1,289 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. 
+ + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradActivationTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + + static int const kAccessesPerVector = 1; + + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() { } + + 
CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // Filter postion (t,r,s,c) in contiguous dimension stays constant for each gemm_iteration_k + int filter_t_[ThreadMap::Iterations::kContiguous]; + int filter_r_[ThreadMap::Iterations::kContiguous]; + int filter_s_[ThreadMap::Iterations::kContiguous]; + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize t,r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * problem_size_.C); + int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); + + filter_r_[c] = residual / (problem_size_.S * problem_size_.C); + residual = residual % (problem_size_.S * problem_size_.C); + + filter_s_[c] = residual / problem_size_.C; + filter_c_[c] = residual % problem_size_.C; + + } + + // initialize n, z, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * 
ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + + // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int t = filter_t_[iteration_contiguous_]; + int r = filter_r_[iteration_contiguous_]; + int s = filter_s_[iteration_contiguous_]; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + int z = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + int d = z * problem_size_.stride_d - problem_size_.pad_d + t * problem_size_.dilation_d; + int h = p * problem_size_.stride_h - problem_size_.pad_h + r * problem_size_.dilation_h; + int w = q * problem_size_.stride_w - problem_size_.pad_w + s * problem_size_.dilation_w; + + return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); + } + + /// Returns true if the current 
coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..7e5475f8f738e4f434f72e1f50a2c5762904cc42 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (activation tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradActivationTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params : Conv3dWgradActivationIteratorOptimizedParams { + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Conv3dWgradActivationIteratorOptimizedParams const &base) + : Conv3dWgradActivationIteratorOptimizedParams(base) {} + + CUTLASS_HOST_DEVICE + 
Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradActivationIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + // Precomputed effective filter postion (t,r,s) in contiguous dimension stays constant for each gemm_iteration_k + // required for nzpq -> ndhw translation + int precomputed_filter_t_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_r_[ThreadMap::Iterations::kContiguous]; + int precomputed_filter_s_[ThreadMap::Iterations::kContiguous]; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int filter_c_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorOptimized( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize t,r,s,c filter position for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int trsc_offset = threadblock_offset.column() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // filter_t_[c] = trsc_offset / (problem_size_.R * problem_size_.S * 
problem_size_.C); + // int residual = trsc_offset % (problem_size_.R * problem_size_.S * problem_size_.C); + // + // filter_r_[c] = residual / (problem_size_.S * problem_size_.C); + // residual = residual % (problem_size_.S * problem_size_.C); + // + // filter_s_[c] = residual / problem_size_.C; + // filter_c_[c] = residual % problem_size_.C; + + int residual; + fast_divmod(precomputed_filter_t_[c], residual, trsc_offset, params_.RSC, params_.rsc_mul, params_.rsc_shr); + fast_divmod(precomputed_filter_r_[c], residual, residual, params_.SC, params_.sc_mul, params_.sc_shr); + fast_divmod(precomputed_filter_s_[c], filter_c_[c], residual, problem_size_.C, params_.c_mul, params_.c_shr); + + int t = precomputed_filter_t_[c]; + int r = precomputed_filter_r_[c]; + int s = precomputed_filter_s_[c]; + + if (problem_size_.mode == Mode::kConvolution) { + t = (problem_size_.T - 1 - t); + r = (problem_size_.R - 1 - r); + s = (problem_size_.S - 1 - s); + } + + // efective t,r,s for every contiguous dimension + precomputed_filter_t_[c] = - problem_size_.pad_d + t * problem_size_.dilation_d; + precomputed_filter_r_[c] = - problem_size_.pad_h + r * problem_size_.dilation_h; + precomputed_filter_s_[c] = - problem_size_.pad_w + s * problem_size_.dilation_w; + + + } + + // initialize n, z, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + offset_nzpq_[s] = threadblock_offset.row() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + } + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void 
advance() { + + // moves to the next GEMM-K offset (offset_nzpq_) in GEMM-B by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kRow * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the activation tensor x that is currently pointed to + /// by the iterator. + + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // int n = offset_nzpq_[iteration_strided_] / (problem_size_.Z * problem_size_.P * problem_size_.Q); + // int residual = offset_nzpq_[iteration_strided_] % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // int z = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, z, p, q; + fast_divmod(n, residual, offset_nzpq_[iteration_strided_], params_.ZPQ, params_.zpq_mul, params_.zpq_shr); + fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); + fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); + + int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; + int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; + int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; + + return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); + } + + /// Returns true if the current coordinate is within the activation tensor x + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() >= 0 && coord.d() < problem_size_.D && + coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && + coord.c() < problem_size_.C; + } + + /// Returns a pointer to the 
vector starting at the current coordinate + CUTLASS_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradActivationTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h new file mode 100644 index 0000000000000000000000000000000000000000..cbe49985f5df8b76bfd1e57552e47577c379f229 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -0,0 +1,267 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. 
+ + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params { + + Layout layout; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() { } + + 
CUTLASS_HOST_DEVICE + Params( + Conv3dProblemSize const &problem_size, + Layout const &layout + ): layout(layout) { + + } + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + int filter_k_[ThreadMap::Iterations::kContiguous]; + + int offset_nzpq_[ThreadMap::Iterations::kStrided]; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorAnalytic( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)) { + + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + // initialize filter_k for every contiguous iteration + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + filter_k_[c] = threadblock_offset.row() + thread_coord.contiguous() + + c * ThreadMap::Delta::kContiguous; + } + + // initialize n, p, q offset for every strided iteration + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] = threadblock_offset.column() + thread_coord.strided() + + s * ThreadMap::Delta::kStrided; + + } + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // 
moves to the next GEMM-K offset (offset_nzpq_) in GEMM-A by a CTA-K tile + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_nzpq_[s] += Shape::kColumn * problem_size_.split_k_slices; + } + } + + /// Returns the coordinate in the output gradient tensor Dy that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int nzpq = offset_nzpq_[iteration_strided_]; + + int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + + int z = residual / (problem_size_.P * problem_size_.Q); + residual = residual % (problem_size_.P * problem_size_.Q); + + int p = residual / problem_size_.Q; + int q = residual % problem_size_.Q; + + return TensorCoord(n, z, p, q, filter_k_[iteration_contiguous_]); + } + + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && + coord.d() < problem_size_.Z && + coord.h() < problem_size_.P && + coord.w() < problem_size_.Q && + coord.c() < problem_size_.K; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + return reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorAnalytic &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given 
problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..6c2f2e51e5e69f28d552839f906237b35d4879db --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (output gradient tile) + matrix from memory. + + This iterator assumes TensorNDHWC layout of tensors in Global Memory. + + The iterator is specialized for each of the three convolution operators: forward propagation (Fprop), + backward data gradient (Dgrad), and backward weight gradient (Wgrad). 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/conv/threadblock/conv3d_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, + typename Element_, + typename ThreadMap_ +> +class Conv3dWgradOutputGradientTileAccessIteratorOptimized { +public: + + // + // Types + // + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorNDHWC; + using ThreadMap = ThreadMap_; + using AccessType = AlignedArray; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 3; + using ConvProblemSize = typename conv::Conv3dProblemSize; + static int const kAccessesPerVector = 1; + static_assert(sizeof_bits::value >= 8, + "WGRAD requires elements of size 8b or greater."); + + // + // Parameters structure + // + + struct Params : Conv3dWgradOutputGradientIteratorOptimizedParams { + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Conv3dWgradOutputGradientIteratorOptimizedParams const &base) + : Conv3dWgradOutputGradientIteratorOptimizedParams(base) {} + + CUTLASS_HOST_DEVICE + 
Params(Conv3dProblemSize const &problem_size, Layout const &layout) + : Conv3dWgradOutputGradientIteratorOptimizedParams( + problem_size, + layout, + sizeof_bits::value, + {Shape::kRow, Shape::kColumn}, + ThreadMap::kThreads, + ThreadMap::kElementsPerAccess, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}) {} + }; + +private: + + Params const ¶ms_; + Conv3dProblemSize const &problem_size_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + char const *pointer_; + + uint32_t predicates_; + int filter_k_; + int offset_nzpq_; + +public: + + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorOptimized( + Params const ¶ms, + Conv3dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + predicates_(0), + filter_k_(0), + offset_nzpq_(0) { + + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.row() + thread_coord.contiguous(); + offset_nzpq_ = threadblock_offset.column() + thread_coord.strided(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int filter_k = filter_k_ + c * ThreadMap::Delta::kContiguous; + int offset_nzpq = offset_nzpq_ + s * ThreadMap::Delta::kStrided; + + bool predicate = valid_(at_(offset_nzpq, filter_k)); + + uint32_t pred = (predicate ? 
1u : 0); + + int pred_idx = c + s * ThreadMap::Iterations::kContiguous; + + predicates_ |= (pred << pred_idx); + } + } + + // Offset pointer to (iteration_strided_, iteration_contiguous_) = (0, 0) + pointer_ += ( + offset_nzpq_ * params.layout.stride()[0] + filter_k_ + ) * sizeof_bits::value / 8; + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv3dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next GEMM-K offset (offset_npq_) in GEMM-A by a CTA-K tile + offset_nzpq_ += Shape::kColumn * problem_size_.split_k_slices; + + // Clear predicates if needed + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + if (offset_nzpq_ + s * ThreadMap::Delta::kStrided >= params_.NZPQ) { + uint32_t kClearMask = ((1u << ThreadMap::Iterations::kContiguous) - 1) << (s * ThreadMap::Iterations::kContiguous); + predicates_ = (predicates_ & (~kClearMask)); + } + } + pointer_ += params_.inc_next_nzpq; + } + +private: + /// Returns the coordinate in the output gradient tensor Dy that is (offset_nzpq, k) pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at_(int offset_nzpq, int k) const { + + // The subseqnet fast_divmod() operations are equivalent to the following logical computation: + // + // + // int nzpq = offset_nzpq_; + // int n = nzpq / (problem_size_.Z * problem_size_.P * problem_size_.Q); + // int residual = nzpq % (problem_size_.Z * problem_size_.P * problem_size_.Q); + // + // int z = residual / (problem_size_.P * problem_size_.Q); + // residual = residual % (problem_size_.P * problem_size_.Q); + // + // int p = residual / problem_size_.Q; + // int q = residual % problem_size_.Q; + + int residual, n, z, p, q; + fast_divmod(n, residual, offset_nzpq, params_.ZPQ, params_.zpq_mul, params_.zpq_shr); + fast_divmod(z, residual, residual, params_.PQ, params_.pq_mul, params_.pq_shr); + fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); + + return TensorCoord(n, z, p, q, k); + } + + /// Returns true if the coord is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid_(TensorCoord coord) const { + + return coord.n() < problem_size_.N && + coord.c() < problem_size_.K; + } + +public: + + /// Returns true if the current coordinate is within the output gradient tensor Dy + CUTLASS_HOST_DEVICE + bool valid() const { + + LongIndex pred_idx = iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous; + return (predicates_ & (1u << pred_idx)); + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + + return reinterpret_cast( + pointer_ + + iteration_strided_ * params_.offset_next_strided + + iteration_contiguous_ * params_.offset_next_contiguous + ); + + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + Conv3dWgradOutputGradientTileAccessIteratorOptimized &operator++() { + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + 
++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv3dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h new file mode 100644 index 0000000000000000000000000000000000000000..f5cd2a740232c257f8e3b25c37408973f536722b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. 
+*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvParams; + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvFilterIteratorParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template<> +struct Depthwise2dFpropDirectConvParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int32_t activation_tile_h; + int32_t activation_tile_w; + int32_t activation_tile_hw; + FastDivmod activation_tile_w_divmod; + + int filter[2]; + int stride[2]; + int dilation[2]; + int inc_next[2]; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_load_count; + int activation_storage_elements; + int activation_size; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams() { } + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams( + 
Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + MatrixCoord threadblock_shape, ///< CTA threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int element_size_bits, ///< bits of activation element + const int thread_count, ///< threads per threadblock + const int thread_count_contiguous, ///< number of threads for continuous dimension + const int element_per_load) ///< element per each load + : layout(layout) { + + filter[0] = problem_size.S; + filter[1] = problem_size.R; + + stride[0] = problem_size.stride_w; + stride[1] = problem_size.stride_h; + + dilation[0] = problem_size.dilation_w; + dilation[1] = problem_size.dilation_h; + + // Compute activation_tile size per threadblock because stride and dilation are runtime params. + activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + + (problem_size.R - 1) * problem_size.dilation_h + 1; + activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + + (problem_size.S - 1) * problem_size.dilation_w + 1; + activation_tile_hw = activation_tile_h * activation_tile_w; + + activation_tile_w_divmod = FastDivmod(activation_tile_w); + + /// Below two values could not be templatized because the stride and dilation are runtime params + activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; + activation_storage_elements = activation_load_count * element_per_load * thread_count; + activation_size = activation_storage_elements * element_size_bits / 8; + + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / + (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + + // next S + inc_next[0] = problem_size.dilation_w; + // next R + 
inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template <> +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int activation_size_ ///< Activation size loaded by iterator + ) + : layout(layout), + activation_size(activation_size_) { + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = + (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template <> +struct Depthwise2dFpropDirectConvFilterIteratorParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + int filter_size; + + bool is_convolution; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams() {} + + CUTLASS_HOST_DEVICE + 
Depthwise2dFpropDirectConvFilterIteratorParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + const int filter_size_) ///< Filter size loaded by iterator + : layout(layout), + filter_size(filter_size_), + is_convolution(problem_size.mode == Mode::kConvolution){} +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h new file mode 100644 index 0000000000000000000000000000000000000000..012e306d800c3bcd62c1322a217d643d2ae38fd5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h @@ -0,0 +1,314 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + // Compilation value of stride , dialtion and activation shape + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ActivationShape = ActivationShape_; + + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + + 
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + // Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. 
+ int iterator_hwc_; + // activation N is inside the Tensor or not + bool valid_n_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0) { + + base_c_ = threadblock_offset.column(); + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + + valid_n_ = activatioin_base_.n() < problem_size_.N; + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + kActivationSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value 
/ 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h = next / ActivationShape::kW; + int w = next % ActivationShape::kW; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + bool valid_c = coord.c() < problem_size_.C; + bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; + bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; + return valid_n_ ? 
valid_c & valid_h & valid_w : 0; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < ThreadMap::Iterations::kCount) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kActivationSize; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return ThreadMap::Iterations::kCount; + } + + /// Determines whether the Depthwise fprop can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check stride and dilation constraint + if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..b8ae9b9312c79f88715fd8cd1efebb2dad8a76f1 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 
1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + // Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. + int iterator_hwc_; + // Number of loads for activations tensor X. + const int number_of_loads_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0), + number_of_loads_(params.activation_load_count) { + + base_c_ = threadblock_offset.column(); + + set_activation_coord(offset_intial_npq_); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + } + + 
CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + sizeof_bits::value, + ThreadMap::kThreads, + ThreadMap::Detail::ShapeVec::kContiguous, + ThreadMap::kElementsPerAccess); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h, w; + params_.activation_tile_w_divmod(h, w, next) ; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < number_of_loads_) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return params_.activation_size; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return number_of_loads_; + } + + /// Determines whether the Depthwise fprop can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..846f1f3aeb269edc67b5e2c02db3f05993172025 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Epilogue stores the data into global memory + typename Epilogue_, + /// iterator implementation variants + conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseFpropDirectConvMultipleStage : + public DepthwiseDirectConvMmaBase { +public: + ///< Base class + using Base = DepthwiseDirectConvMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using Epilogue = Epilogue_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = 
CacheOpB; + + static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropDirectConvMultipleStage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, 
lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + // Number of iterators is a static value. 
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA &iterator_A, + ///< Params of global memory iterator + typename IteratorA::Params const &iterator_a_params, + ///< iterator over B operand in global memory + IteratorB &iterator_B, + ///< Params of global memory iterator + typename 
IteratorB::Params const &iterator_b_params, + ///< initial value of accumulator + FragmentC const &src_accum, + /// Epilogue + Epilogue &epilogue, + ///< Output operator + typename Epilogue::OutputOp const &output_op, + ///< Tile iterator for destination + typename Epilogue::OutputTileIterator &destination_iterator, + ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + typename Epilogue::OutputTileIterator &source_iterator, + + int split_k_slices = 1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + if (stage == 0) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + } + + if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ + // Number of iterators is compilation static. 
+ iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); + this->smem_iterator_A_.set_iteration_index(0); + + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + // Move to the next stage + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + unsigned int iterations = 0; + constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. + + accum.clear(); + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { + if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k == 0) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k < Base::kWarpGemmIterations) { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == inner_loop_iterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == inner_loop_iterations) { + // Inserts a fence to group cp.async instructions into stages. 
+ cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next cta + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); + + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); + smem_read_stage_idx = 0; + } else { + this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); + ++smem_read_stage_idx; + } + + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + } + + // goback to start position. B has no multiple stage + this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); + + --gemm_k_iterations; + } + } + + // + // Epilogue + // + int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); + + destination_iterator.set_tile_index(iterations * split_k_slices); + + source_iterator.set_tile_index(iterations * split_k_slices); + + epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); + + ++iterations; + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 0000000000000000000000000000000000000000..1035fda375787cc211929854b662d1ccb7a809ae --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +template > +class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { +public: + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + 
using Params = Depthwise2dFpropDirectConvFilterIteratorParams; + + protected: + + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int offset_trs_[ThreadMap::Iterations::kStrided]; + +public: + + + + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Do nothing because the filter is persistent in the SMEM + } + + /// Returns the coordinate in the filter 
tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + int trs = offset_trs_[iteration_strided_]; + + return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && + coord.h() < Shape::kColumn; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + int64_t offset = coord.n(); + if (params_.is_convolution) { + offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; + } else { + offset += coord.h() * problem_size_.K; + } + + return reinterpret_cast(pointer_ + + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines the filter size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kFilterSize; + } + + /// Determines whether the Implicit GEMM can execute the given problem. 
+ CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // check whether runtime filter size is same as templated filter size. + if ((problem_size.R * problem_size.S) != Shape::kColumn) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..30d13e9087e4b47384a43b9381036f668d581808 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial 
specialization + typename Enable = bool +> +class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B 
operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< 
transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + // Depthwise specific + int channel_start_index = 0; + int rs_plane_idx = 0; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Reset interation index. + iterator_B.set_iteration_index(0); + } + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Move to next set of filter groups. + channel_start_index += Base::kWarpGemmIterations; + } + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + + rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 
0: (rs_plane_idx + 1); + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..44dafcb5fa4f099e8070a9c8d271c4048128ceac --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_base.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a directconv threadblock-scoped Depthwise kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy object describing MmaTensorOp +template < + /// Warp-level GEMM operator (concept: gemm::warp::Mma) + typename Operator_, + /// Padding used for A operand in shared memory (concept: MatrixShape) + typename SmemPaddingA_, + /// Padding used for B operand in shared memory (concept: MatrixShape) + typename SmemPaddingB_, + /// + typename ThreadMapA_, + /// + typename ThreadMapB_, + /// Number of partitions of K dimension of GEMM + int PartitionsK = 1> +struct DepthwiseDirectConvMmaPolicy { + /// Warp-level 
GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) + using Operator = Operator_; + + /// Padding used for A operand in shared memory + using SmemPaddingA = SmemPaddingA_; + + /// Padding used for B operand in shared memory + using SmemPaddingB = SmemPaddingB_; + + using ThreadMapA = ThreadMapA_; + using ThreadMapB = ThreadMapB_; + + /// Number of partitions of K dimension + static int const kPartitionsK = PartitionsK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseDirectConvMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm:: + GemmShape; + + /// Number of warp-level GEMM oeprations + /// kWarpGemmIterations could be even and odd. 
+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape<1, // Not determined at compile-time :( + Shape::kN + Policy::SmemPaddingA::kRow>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; // Tile N = 64? + + public: + // + // Data members + // + + // Let persistent B matrix in front of dynamic matrix A + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for A operand + /// Not be determined at compile-time -- Just to get a Smem start address. 
+ AlignedBuffer operand_A; + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseDirectConvMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h new file mode 100644 index 0000000000000000000000000000000000000000..9e3cc417d4cc3169724f9e5db9e82fa093121fae --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -0,0 +1,952 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data + layout of the global memory fragments, data types, and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/warp/mma_depthwise_simt.h" + +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_singlestage.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" + +#include "cutlass/arch/cache_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +namespace detail { +// +// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by +// each partitions within warp. 
+// The goal is for each thread's tile of elements to be as square as +// possible for performance (4x4 will be faster than 2x8). +template // The number of partitions within the warp +struct SimtWarpShape { + // kP * kQ * WarpNumThreadsM = WarpShapeM + // If needed, enable more specializations. +}; +template <> +struct SimtWarpShape<4, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<8, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<8, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; +template <> +struct SimtWarpShape<8, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<16, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template +struct SimtWarpShape<25, WarpNumThreadsM> { + static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); + static constexpr int kP = 5; + static constexpr int kQ = 5; +}; + +template <> +struct SimtWarpShape<32, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 8; +}; + +template <> +struct SimtWarpShape<32, 2> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; + +template <> +struct SimtWarpShape<32, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; + +} // namespace detail + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix 
multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of threadblock-scoped output tile + typename ThreadBlockOutputShape, + /// Shape of filter shape per threadblock + typename FilterShape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + 
cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + 
/// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA, + /// per-element transformation for elements of B + ComplexTransform TransformB, + bool IsComplex +> +struct DepthwiseMmaCoreWithLaneAccessSize< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> : cutlass::gemm::threadblock::DefaultMmaCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeA_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) 
+ int kLaneAccessSizeB_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + kLaneAccessSizeB_, + 2, + Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_> { + using Base = cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_>; + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeA = kLaneAccessSizeA_; + static int const kLaneAccessSizeB = kLaneAccessSizeB_; + + // Divisility requirements + static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = typename Base::WarpCount; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
+ ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory are same as base class + // + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); + static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< + WarpShape, /// Size of 
the Gemm problem - concept: gemm::GemmShape<>
      ElementA,     /// Data type of A elements
      SmemLayoutA,  /// Layout of A matrix (concept: MatrixLayout)
      ElementB,     /// Data type of B elements
      SmemLayoutB,  /// Layout of B matrix (concept: MatrixLayout)
      ElementC,     /// Element type of C matrix
      LayoutC,      /// Layout of C matrix (concept: MatrixLayout)
      Policy        /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
  >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
      MmaWarpSimt,
      MatrixShape,               // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
      WarpCount::kK
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: row-major
///   Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
///
/// Threadblock-scoped depthwise direct-conv MMA core: derives the shared-memory
/// layouts, global/shared-memory thread maps, and the warp-level SIMT policy
/// from the threadblock and warp tile shapes.
//
// NOTE(review): several template-argument lists in this vendored header appear
// to have been stripped during extraction (e.g. the specialization arguments of
// "DepthwiseDirectConvMmaCoreWithLaneAccessSize," and bare "MatrixShape," /
// "sizeof_bits::value" below carry no <...> arguments), so this file cannot
// compile as-is — restore the argument lists from the upstream CUTLASS header
// before use.
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
    typename ThreadBlockOutputShape_,
    /// Shape of filter shape per threadblock
    typename FilterShape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_,
    /// Number of stages
    int Stages_,
    /// Operation performed by GEMM
    typename Operator_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize,
    ElementA_,
    layout::RowMajor,
    ElementB_,
    layout::ColumnMajor,
    ElementC_,
    LayoutC_,
    arch::OpClassSimt,
    kLaneAccessSizeA_,
    128,
    Stages_,
    Operator_> {
  using Shape = Shape_;
  using FilterShape = FilterShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassSimt;

  // B operand is always fetched with a 128-bit per-thread access in this
  // specialization.
  static int const kLaneAccessSizeB = 128;

  // Divisibility requirements
  static_assert( kLaneAccessSizeB > 0,
    "Size of a warp-scoped per thread access should be larger then ZERO" );

  /// Default Operator
  using Operator = Operator_;

  /// Number of warps present
  using WarpCount = cutlass::gemm::GemmShape<
      Shape::kM / WarpShape::kM,
      Shape::kN / WarpShape::kN,
      1
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = cutlass::gemm::warp::WarpSize::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  // For Gmem load: elements per 128-bit global-memory access for each operand.
  static int const kElementsPerAccessA = 128 / sizeof_bits::value;
  static int const kElementsPerAccessB = 128 / sizeof_bits::value;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajor;
  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value.
      kThreads,
      kElementsPerAccessA
  >;

  /// ThreadMap of iterator A
  using SmemThreadMapA = IteratorThreadMapA;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value
      ElementA,
      SmemLayoutA,
      0,
      SmemThreadMapA, // was IteratorThreadMapA
      true // Dynamic iterations.
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape,
      kThreads,
      kElementsPerAccessB
  >;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = IteratorThreadMapB;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape,
      ElementB,
      SmemLayoutB,
      0,
      SmemThreadMapB, // was IteratorThreadMapB
      false // static iterations.
  >;

  //
  // Warp-level matrix multiply operator
  //
  // Groups per threads
  //  Fp32: 2 groups
  //  Fp16: 2 groups
  // (>1-byte element types get 2 groups, 1-byte types get 4.)
  static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
  // Define the warp-level op
  static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
  static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;

  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
      "WarpShape must be divisible by ThreadTile shape.");

  // Get output P, Q per thread
  static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP;
  static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ;

  static const int LaneLayout = 1;
  static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value;
  static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);

  // Define the output tile computed by each thread
  using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;

  // Fetch the channel with same access size
  static const int LaneM = LaneN;

  // No paddings
  static int const kPaddingM = 0;
  static int const kPaddingN = 0;

  static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
      "Padding must be divisible by Lane");

  // these should have max of thread tile also
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;

  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape,                 // WarpShape
      cutlass::layout::RowMajorInterleaved, // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
      WarpShape,               /// Size of the Gemm problem - concept: gemm::GemmShape<>
      FilterShape,             /// Shape of filter shape per threadblock - concept: gemm::GemmShape
      ThreadOutputShape,       /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
      ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
      ElementA,                /// Data type of A elements
      SmemLayoutA,             /// Layout of A matrix (concept: MatrixLayout)
      ElementB,                /// Data type of B elements
      SmemLayoutB,             /// Layout of B matrix (concept: MatrixLayout)
      ElementC,                /// Element type of C matrix
      LayoutC,                 /// Layout of C matrix (concept: MatrixLayout)
      Policy                   /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
  >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
      MmaWarpSimt,
      MatrixShape,               // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
      IteratorThreadMapA,
      IteratorThreadMapB,
      WarpCount::kK
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: row-major
///   Operator: simt class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape)
    typename ThreadBlockOutputShape_,
    /// Shape of filter shape per threadblock
    typename FilterShape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Size of a warp-scoped per thread access
    int kLaneAccessSizeA_,
    /// Number of stages
    int Stages_,
    /// Operation performed by GEMM
    typename Operator_,
    /// Stride ( MatrixShape )
    typename StrideShape_,
    /// Dilation ( MatrixShape )
    typename DilationShape_,
    /// Activation Shape loaded by threadblock
    typename ActivationShape_>
struct DepthwiseDirectConvMmaCoreWithLaneAccessSize,
    ElementA_,
    layout::RowMajor,
    ElementB_,
    layout::ColumnMajor,
    ElementC_,
    LayoutC_,
arch::OpClassSimt,
    kLaneAccessSizeA_,
    128,
    Stages_,
    Operator_,
    IteratorAlgorithm::kFixedStrideDilation,
    StrideShape_,
    DilationShape_,
    ActivationShape_> {
  // Specialization for the kFixedStrideDilation iterator algorithm: stride,
  // dilation, and the threadblock-loaded activation shape are compile-time
  // template parameters, so both shared-memory iterators use static iteration
  // counts (contrast with the dynamic-activation-shape specialization above).
  //
  // NOTE(review): as elsewhere in this vendored header, several <...>
  // template-argument lists appear stripped (e.g. bare "MatrixShape," and
  // "sizeof_bits::value"); restore from the upstream CUTLASS header before use.
  using Shape = Shape_;
  using FilterShape = FilterShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassSimt;
  using StrideShape = StrideShape_;
  using DilationShape = DilationShape_;
  using ThreadBlockOutputShape = ThreadBlockOutputShape_;
  using ActivationShape = ActivationShape_;

  static int const kLaneAccessSizeB = 128;

  // Divisibility requirements
  static_assert( kLaneAccessSizeB > 0,
    "Size of a warp-scoped per thread access should be larger then ZERO" );

  /// Default Operator
  using Operator = Operator_;

  /// Number of warps present
  using WarpCount = cutlass::gemm::GemmShape<
      Shape::kM / WarpShape::kM,
      Shape::kN / WarpShape::kN,
      1
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = cutlass::gemm::warp::WarpSize::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  // For Gmem load: elements per 128-bit global-memory access for each operand.
  static int const kElementsPerAccessA = 128 / sizeof_bits::value;
  static int const kElementsPerAccessB = 128 / sizeof_bits::value;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajor;
  using SmemLayoutB = layout::RowMajor;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape,
      kThreads,
      kElementsPerAccessA
  >;

  /// ThreadMap of iterator A
  using SmemThreadMapA = IteratorThreadMapA;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape,
      ElementA,
      SmemLayoutA,
      0,
      SmemThreadMapA, // was IteratorThreadMapA
      false // static iterations.
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
      layout::PitchLinearShape,
      kThreads,
      kElementsPerAccessB
  >;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = IteratorThreadMapB;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv<
      MatrixShape,
      ElementB,
      SmemLayoutB,
      0,
      SmemThreadMapB, // was IteratorThreadMapB
      false // static iterations.
  >;

  //
  // Warp-level matrix multiply operator
  //
  // Groups per threads
  //  Fp32: 2 groups
  //  Fp16: 2 groups
  // (>1-byte element types get 2 groups, 1-byte types get 4.)
  static const int GroupsPerThread = sizeof(ElementB) > 1 ? 2 : 4;
  // Define the warp-level op
  static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize);
  static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN;

  static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP;
  static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ;

  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
      "WarpShape must be divisible by ThreadTile shape.");

  static const int LaneLayout = 1;
  static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value;
  static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN);

  // Define the output tile computed by each thread
  using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>;

  // Fetch the channel with same access size
  static const int LaneM = LaneN;

  // No paddings
  static int const kPaddingM = 0;
  static int const kPaddingN = 0;

  static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN),
      "Padding must be divisible by Lane");

  // these should have max of thread tile also
  using LaneMmaShape = cutlass::gemm::GemmShape<
      LaneM,
      LaneN,
      1>;

  using Policy = cutlass::gemm::warp::MmaSimtPolicy<
      cutlass::MatrixShape,                 // WarpShape
      cutlass::layout::RowMajorInterleaved, // LaneLayout
      LaneMmaShape
  >;

  using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt<
      WarpShape,              /// Size of the Gemm problem - concept: gemm::GemmShape<>
      FilterShape,            /// Shape of filter shape per threadblock - concept: gemm::GemmShape
      ThreadOutputShape,      /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<>
      ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<>
      ElementA,               /// Data type of A elements
      SmemLayoutA,            /// Layout of A matrix (concept: MatrixLayout)
      ElementB,               /// Data type of B elements
      SmemLayoutB,            /// Layout of B matrix (concept: MatrixLayout)
      ElementC,               /// Element type of C matrix
      LayoutC,                /// Layout of C matrix (concept: MatrixLayout)
      Policy,                 /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy)
      IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type
      StrideShape,            /// Stride ( MatrixShape )
      DilationShape,          /// Dilation ( MatrixShape )
      ActivationShape         /// Activation Shape loaded by threadblock
  >;

  /// Policy used to define MmaPipelined
  using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy<
      MmaWarpSimt,
      MatrixShape,               // skew for A matrix to avoid SMEM bank conflicts
      MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
      IteratorThreadMapA,
      IteratorThreadMapB,
      WarpCount::kK
  >;
};
} // namespace threadblock
} // namespace conv
} // namespace cutlass
diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h
new file mode 100644
index 0000000000000000000000000000000000000000..482a52fe63209650546811aa24cafcc7419e7479
--- /dev/null
+++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h
@@ -0,0 +1,802 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1.
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped fused activation's + scale+bias+relu and Implicit GEMM Convolution kernel. + + The original implicit gemm will store out-of-bound data as zeroes in the + shared memory because zeros into the tensor core, zeroes out of the tensor + cores. The result is remained the same. When fusing scale+bias+relu + into the mainloop, it is no longer true because + + 0 x scale + bias = bias + + which is no longer always 0. 
So, instead of storing zeroes, this fused + kernel stores the out-of-bound data as a special NaN (0x7eff), when applying + scale+bias+relu, the code is like + + if (data == 0x7eff) + data = 0; + else + data = scale+bias+relu(data, scale, bias); + + See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the + elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/conv/warp/scale_bias_relu_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Element type of scale and bias vectors
    typename ElementScaleBias_,
    /// Layout of scale and bias vectors
    typename LayoutScaleBias_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// WarpIterator to load Scale or Bias vector from the shared memory
    typename WarpIteratorScaleBias_,
    /// Number of stages
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class MmaFpropFusionBase {
  // Base class for the multistage fused scale+bias+relu fprop mainloop: owns
  // the shared-memory buffers for the A/B tiles and the A scale/bias vectors,
  // and the warp-level iterators that read them back out of shared memory.
  //
  // NOTE(review): several <...> template-argument lists below appear to have
  // been stripped during extraction (e.g. "GemmShape;", "MatrixShape;",
  // "TensorRef;", "AlignedBuffer operand_A;"), so this vendored header will
  // not compile as-is; restore the argument lists from the upstream CUTLASS
  // implicit_gemm_fprop_fusion_multistage.h before use.
 public:
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;

  ///< Element type of scale and bias vectors
  using ElementScaleBias = ElementScaleBias_;

  /// Layout of scale and bias vectors
  using LayoutScaleBias = LayoutScaleBias_;

  ///< Policy describing tuning details
  using Policy = Policy_;

  ///< WarpIterator to load Scale or Bias vector from the shared memory
  using WarpIteratorScaleBias = WarpIteratorScaleBias_;

  //
  // Dependent types
  //

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
  using WarpGemm = typename Policy::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = cutlass::gemm::GemmShape;

  /// Number of warp-level GEMM operations
  static int const kWarpGemmIterations =
      (WarpGemm::kK / Operator::Policy::MmaShape::kK);

  /// Number of stages
  static int const kStages = Stages;

  /// Tensor reference to the A operand
  using TensorRefA = TensorRef;

  /// Tensor reference to the scale and bias vectors
  using TensorRefScaleBias = TensorRef;

  /// Tensor reference to the B operand
  using TensorRefB = TensorRef;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

  static_assert((kWarpGemmIterations % 2) == 0,
                "Inner loop iteration must be an even number.");

  //
  // Nested structs
  //

  /// Shared storage object needed by threadblock-scoped GEMM
  class SharedStorage {
   public:
    //
    // Type definitions
    //

    /// Shape of the A matrix operand in shared memory
    using ShapeA = MatrixShape;

    /// Shape of the A scale and bias vectors in shared memory
    // (factor 2: one scale row and one bias row per k-slice, staged kStages deep)
    using ShapeScaleBias =
        MatrixShape<1 + Policy::SmemPaddingA::kRow,
                    2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB =
        MatrixShape;

   public:
    //
    // Data members
    //

    /// Buffer for A operand
    AlignedBuffer operand_A;

    /// Buffer for B operand
    AlignedBuffer operand_B;

    /// Buffer for A operand Scale and Bias
    AlignedBuffer operand_A_scale_bias;

   public:

    //
    // Methods
    //

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
    static typename Operator::LayoutA LayoutA() {
      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
    static typename Operator::LayoutB LayoutB() {
      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
    }

    /// Returns a layout object for the A scale and bias vectors
    CUTLASS_DEVICE
    static LayoutScaleBias LayoutScaleBias() {
      return LayoutScaleBias::packed(
          {ShapeScaleBias::kRow, ShapeScaleBias::kColumn});
    }

    /// Returns a TensorRef to the A operand
    CUTLASS_HOST_DEVICE
    TensorRefA operand_A_ref() {
      return TensorRefA{operand_A.data(), LayoutA()};
    }

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
    TensorRefB operand_B_ref() {
      return TensorRefB{operand_B.data(), LayoutB()};
    }

    /// Returns a TensorRef to the A operand Scale vector
    CUTLASS_HOST_DEVICE
    TensorRefScaleBias operand_A_scale_bias_ref() {
      return TensorRefScaleBias{operand_A_scale_bias.data(), LayoutScaleBias()};
    }
  };

 protected:

  //
  // Data members
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
  typename Operator::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of A operand scale and bias vector
  /// from shared memory
  WarpIteratorScaleBias warp_tile_iterator_A_scale_bias_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
  typename Operator::IteratorB warp_tile_iterator_B_;

public:

  /// Construct from tensor references
  // All three warp iterators are positioned by lane_idx only here; the derived
  // class applies the per-warp tile offsets.
  CUTLASS_DEVICE
  MmaFpropFusionBase(
      ///< Shared storage needed for internal use by threadblock-scoped GEMM
      SharedStorage &shared_storage,
      ///< ID within the threadblock
      int thread_idx,
      ///< ID of warp
      int warp_idx,
      ///< ID of each thread within a warp
      int lane_idx)
      : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_A_scale_bias_(
            shared_storage.operand_A_scale_bias_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorScaleBias_, + /// Iterates over vectors of scale and bias vector in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorScaleBias_, + /// Cache operation for scale/bias operand + cutlass::arch::CacheOperation::Kind CacheOpScaleBias, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// WarpIterator to load Scale or Bias vector from the shared memory + typename WarpIteratorScaleBias_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmFpropFusionMultistage + : public MmaFpropFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over 
tiles of the scale and bias vectors in global memory + using IteratorScaleBias = IteratorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from the shared memory + using WarpIteratorScaleBias = WarpIteratorScaleBias_; + ///< Policy describing tuning details + using Policy = Policy_; + ///< Base class + using Base = MmaFpropFusionBase; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScaleBias = SmemIteratorScaleBias_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpScaleBias = + CacheOpScaleBias; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. 
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpLoadedFragmentScaleBias = + typename WarpIteratorScaleBias::Fragment; + + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory + SmemIteratorScaleBias smem_iterator_A_scale_bias_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmFpropFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the 
threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_A_scale_bias_(shared_storage.operand_A_scale_bias_ref(), + thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_bias_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorScaleBias &iterator_A_scale_bias, + IteratorB &iterator_B, int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + 
IteratorA::ThreadMap::kElementsPerAccess / 8; + + // Uses nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + + ++this->smem_iterator_A_; + } + } + + // Async Copy for operand A scale and bias vector. Scale and bias vectors + // are small. One iteration is enough. + if (group_start_A == 0) { + typename IteratorScaleBias::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_scale_bias_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorScaleBias::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); + } + + iterator_B.set_iteration_index(group_start_B); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale and bias vectors in global memory + IteratorScaleBias iterator_A_scale_bias, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t 
imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + // Uses Nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + ++this->smem_iterator_A_; + } + + // Async Copy for operand A scale and bias vectors. Scale and bias + // vectors are small. One iteration is enough. + { + typename IteratorScaleBias::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_scale_bias_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorScaleBias::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_scale_bias.get(), iterator_A_scale_bias.valid()); + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_A_scale_bias.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_scale_bias_.add_tile_offset({0, 
1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpLoadedFragmentScaleBias warp_loaded_frag_A_scale_bias[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + cutlass::conv::warp::FpropScaleBiasReluTransform + elementwise_transform; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_scale_bias_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_A_scale_bias_.load( + warp_loaded_frag_A_scale_bias[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_bias_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + elementwise_transform(warp_transformed_frag_A[0], + warp_loaded_frag_A_scale_bias[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix 
multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_scale_bias_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_bias_.load( + warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_bias_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_loaded_frag_A_scale_bias[warp_mma_k % 2]); + } + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_A_scale_bias, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + 
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + elementwise_transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_A_scale_bias[(warp_mma_k + 1) % 2]); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_A_scale_bias.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_scale_bias_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_A_scale_bias_.add_tile_offset( + {0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_A_scale_bias_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..6c9c4792e289824afd1a761f5b7b4cc5972f167a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -0,0 +1,539 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmMultistage : + public gemm::threadblock::MmaBase { +public: + ///< Base class + using Base = gemm::threadblock::MmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = 
typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. 
+ static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + 
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA &iterator_A, IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< 
iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + 
+ this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
+ plus plus_accum; + + FragmentC tmp_accum; + + if (Detail::kStagedAccumulation) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, + group_start_iteration_B); + + if (Detail::kStagedAccumulation) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum + ); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } 
+ } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + if (Detail::kStagedAccumulation) { + accum = plus_accum(accum, tmp_accum); + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..45e27949665f797ba28afcd5f1cf98007c56eac9 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class ImplicitGemmPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; 
+ using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // 
three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to 
overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..3be08c1ad90cf896b0b2191aa0c0a4a5a8c5b033 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h @@ -0,0 +1,729 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped fused activation's scale+bias+relu and + Implicit GEMM Convolution kernel. + + The original implicit gemm will store out-of-bound data as zeroes in the + shared memory because zeros into the tensor core, zeroes out of the tensor + cores. The result is remained the same. When fusing scale+bias+relu + into the mainloop, it is no longer true because + + 0 x scale + bias = bias + + which is no longer always 0. So, instead of storing zeroes, this fused + kernel stores the out-of-bound data as a special NaN (0x7eff), when applying + scale+bias+relu, the code is like + + if (data == 0x7eff) + data = 0; + else + data = scale+bias+relu(data, scale, bias); + + The biggest difference compared with the fused Fprop and scale+bias+relu is + that scale and bias are loop invariant in Wgrad so that they only needs to + be loaded once before the mainloop. + + See include/cutlass/conv/warp/scale_bias_relu_transformation.h for the + elementwise computation. See include/cutlass/arch/memory_sm80.h for nan fill. 
+ + +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/conv/warp/scale_bias_relu_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Element type of scale and bias vectors + typename ElementScaleBias_, + /// Layout of scale and bias vectors + typename LayoutScaleBias_, + /// Element type of scale and bias vectors + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaWgradFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Element type of scale and bias vectors + using ElementScaleBias = ElementScaleBias_; + + /// Layout of scale and bias vectors + using LayoutScaleBias = LayoutScaleBias_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm::GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// 
Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaWgradFusionBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename 
IteratorScaleBias_, + /// Iterates over vectors of scale and bias vector i + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class ImplicitGemmWgradFusionMultistage + : public MmaWgradFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorScaleBias = IteratorScaleBias_; + ///< Policy describing tuning details + using Policy = Policy_; + ///< Base class + using Base = MmaWgradFusionBase; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. 
+ struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const kBBufferSize = + ((sizeof(typename Operator::ElementC) == 4) && + ((platform::is_same::value && + platform::is_same::value)) && + (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) + ? 
1 + : 2; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpLoadedFragmentScaleBias = typename IteratorScaleBias::Fragment; + + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + int warp_idx_m_; + + int warp_idx_n_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + ImplicitGemmWgradFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; + warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); + 
this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + + iterator_A.set_iteration_index(group_start_A); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B); + + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + // Uses nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale and bias vectors in global memory + IteratorScaleBias 
iterator_B_scale_bias, + ///< initial value of accumulator + FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + WarpLoadedFragmentScaleBias warp_loaded_frag_B_scale_bias; + iterator_B_scale_bias.add_tile_offset({0, warp_idx_n_}); + iterator_B_scale_bias.load(warp_loaded_frag_B_scale_bias); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / 8; + + // Uses Nan fill for out of bound data + cutlass::arch::cp_async_nan( + dst_ptr, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async 
instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[Detail::kBBufferSize]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[Detail::kBBufferSize]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + cutlass::conv::warp::WgradScaleBiasReluTransform + elementwise_transform; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + elementwise_transform(warp_transformed_frag_B[0], + warp_loaded_frag_B_scale_bias); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + if (Detail::kBBufferSize == 2) { + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize]); + ++this->warp_tile_iterator_A_; + } + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_loaded_frag_B[warp_mma_k % 2]); + + elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_B_scale_bias); + } + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % Detail::kBBufferSize], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + if (Detail::kBBufferSize == 1) { + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + ++this->warp_tile_iterator_A_; + + } + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % Detail::kBBufferSize], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + elementwise_transform( + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_B_scale_bias); + } + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; 
+ } + + copy_tiles_and_advance(iterator_A, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.advance(); + iterator_B.advance(); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..dac642385cd445e9a36a2f4c6f6c9e51f309cb87 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorAccessIterator +/// +template +class PredicatedScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; + + using AccessType = AlignedArray; + + using Params = PredicatedScaleBiasVectorAccessIteratorParams; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + int problem_size_trs; + int problem_size_c; + int 
filter_trs_; + + TensorCoord thread_offset_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.T * problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { + pointer_ = (thread_id < kThreads) + ? 
reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + thread_offset_ = + thread_offset_ + + TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + 
(thread_offset_.contiguous() * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + void advance() { + // moves to the next tile + ++filter_trs_; + if (filter_trs_ == problem_size_trs) { + filter_trs_ = 0; + add_tile_offset(TensorCoord(1, 0)); + } + } + + /// Increment and return an instance to self. + CUTLASS_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + uint32_t enabled = 0; + +#if defined(_MSC_VER) || (__CUDACC_VER_MAJOR__ < 11) + enabled = threadIdx.x < kThreads * 2; +#else + asm volatile( + "{\n" + " .reg .u32 tid_reg;\n" + " .reg .pred p;\n" + " mov.u32 tid_reg, %%tid.x;\n" + " setp.lt.u32 p, tid_reg, %1;\n" + " selp.u32 %0, 1, 0, p;\n" + "}\n" : "+r"(enabled) :"n"(kThreads * 2)); +#endif + + return ((thread_offset_.contiguous() < problem_size_c) && enabled); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + + using Params = PredicatedScaleBiasVectorAccessIteratorParams; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Extent of tensor + Conv2dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Precomputed parameters object + 
Params const ¶ms, + ///< Extent of tensor + Conv3dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv2dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv3dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + 
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + void advance() { + iterator_.advance(); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..e9844be9f000920fd82f18dc6dab5755611f08ea --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/predicated_scale_bias_vector_iterator.h @@ -0,0 +1,371 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorIterator +/// +template +class PredicatedScaleBiasVectorIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 1; + + using AccessType = AlignedArray; + + static int const kIterations = WarpShape::kContiguous / 8; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; + + /// Parameters object is precomputed state and is host-constructible + using Params = Conv2dWgradActivationIteratorOptimizedParams; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + ConstPointer scale_pointer_; + ConstPointer bias_pointer_; + + /// Size of tensor + Conv2dProblemSize 
problem_size_; + + int32_t thread_offset_; + + // Channel dimension in contiguous dimension stays constant for each gemm_iteration_k + int32_t filter_c_[kIterations]; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_(problem_size), + scale_pointer_(scale_pointer), + bias_pointer_(bias_pointer) { + + thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; + } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv2dProblemSize const &problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); + + CUTLASS_PRAGMA_UNROLL + for(int c = 0; c < kIterations; ++c) { + int rsc_offset = thread_offset_ + c * 8; + + int residual, tmp; + params_.sc_divmod(tmp, residual, rsc_offset); + params_.c_divmod(tmp, filter_c_[c], residual); + } + } + + /// 
Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.fill(__float2half2_rn(0.0f)); + __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); + + // load scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2].x, + scale_pointer_ + filter_c_[c], + true + ); + } + + // load bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2 + 1].x, + bias_pointer_ + filter_c_[c], + true + ); + } + + // duplicate scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2].y = frag_ptr[c * 2].x; + } + + // duplicate bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + using Fragment = typename UnderlyingIterator::Fragment; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedScaleBiasVectorIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Conv2dProblemSize const &problem_size, Layout const &layout) + : params_(problem_size, layout::TensorNHWC(0, 0, 0)){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Extent of tensor + Conv2dProblemSize const &problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< 
Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + Params const ¶ms, ///< Precomputed parameters object + Conv2dProblemSize const &problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorIterator(params, problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + iterator_.load(frag); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..0c5aed6dba0fa206fcab9545eeeb165558cb724a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/threadblock/threadblock_swizzle.h @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements several possible threadblock-swizzling functions mapping blockIdx to + Convolution problems. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_HOST_DEVICE +static int get_strided_dgrad_tile_m( + cutlass::conv::Conv2dProblemSize const &problem_size, + int tile_size_m) { + + // CTAs in M dimension per starting filter position + int tile_m_per_filter = strided_dgrad_tile_m_per_filter(problem_size, tile_size_m); + + // Inflate number of CTAs in M dimension to cover every strating filter position even those that + // may fall out of valid MMA (Dy * w) but are needed to apply epilogue (beta * Dx_source) + // and point-wise fusion + int tile_m = tile_m_per_filter * int(problem_size.stride().product()); + + // There is a possible performance optimization here that leads up to 2x speeds 
than the current + // CUTLASS strided dgrad performance for stride > filter, i.e., stride={2x2} and filter={1x1}) + // + // * Optimization * + // Only launch CTAs in M dimension which contribute to a row in Dx output + // + // + // * Constraints * + // (A) stride <= filter, for example, stride={2x2} and filter={3x3}: + // - (A.1): There are no constraints for this case and the optimization does + // affect this case functionality or performance. + // (B) stride > filter, for example, stride={2x2} and filter={1x1}: + // - (B.1): Dx output tensor should be zero initialized + // - (B.2): The kernel epilogue cannot apply beta. Thus, beta should be zero + + return tile_m; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +struct StridedDgradHorizontalThreadblockSwizzle : + public gemm::threadblock::GemmHorizontalThreadblockSwizzle { + + using Base = gemm::threadblock::GemmHorizontalThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradHorizontalThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + 
split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for strided dgrad convolution +template +struct StridedDgradIdentityThreadblockSwizzle : + public gemm::threadblock::GemmIdentityThreadblockSwizzle { + + using Base = gemm::threadblock::GemmIdentityThreadblockSwizzle; + + CUTLASS_HOST_DEVICE + StridedDgradIdentityThreadblockSwizzle() { } + + /// Returns the shape of the problem in units of logical tiles + /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + // compute number of tiles in m dimension + int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); + + // compute number of tiles in n dimension + int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); + + return gemm::GemmCoord( + tile_m, + tile_n, + split_k_slices); + } + + /// Returns the shape of the problem in units of logical tiles + /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) + private: + using Base::get_tiled_shape; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Threadblock swizzling function for GEMMs +template +struct DepthwiseDirect2dConvIdentityThreadblockSwizzle + : public gemm::threadblock::GemmIdentityThreadblockSwizzle { + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} + + 
/// Returns the shape of the problem in units of logical tiles + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + return gemm::GemmCoord(1, + (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), + split_k_slices); + } +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h new file mode 100644 index 0000000000000000000000000000000000000000..b7af2e37bd610a12f334943902395b6956362589 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt.h @@ -0,0 +1,380 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/thread/depthwise_mma.h" + + +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" + +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseSimt + : public cutlass::gemm::warp:: + MmaSimt { + using Base = cutlass::gemm::warp:: + MmaSimt; + +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape 
= Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + +public: + + /// Iterates over the B operand in memory + using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseSimt():Base() {} +}; + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + typename FilterShape_, + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + typename ThreadOutputShape_, + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape_ = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape_ = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseDirectConvSimt { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + using FilterShape = FilterShape_; + + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape 
of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Iterator algo type + static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || + platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && + platform::is_same< ElementA, int8_t >::value && + platform::is_same< ElementB, int8_t >::value; + + using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; + + /// Thread-level matrix multiply accumulate operator + using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< + cutlass::gemm::GemmShape< + Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread + Shape::kN / Policy::WarpShape::kColumn, // number of channels 
processed per thread + 1>, + ElementA, + ElementB, + ElementC, + arch::OpMultiplyAdd, + dp4a_type + >; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Shape of the underlying instruction + using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; + +public: + + /// Iterates over the A operand in memory + using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< + MatrixShape, // per warp + FilterShape, + ThreadOutputShape, + ThreadBlockOutputShape, + cutlass::gemm::Operand::kA, + ElementA, + Policy, + IteratorAlgorithm, + StrideShape, + DilationShape, + ActivationShape, + PartitionsK, + Shape::kK + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape<1, Shape::kN>, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = FragmentB; + + /// Iterates over the C operand in memory + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kC, + ElementC, + LayoutC, + Policy + >; + + /// Storage for C tile + using FragmentC = typename ThreadMma::FragmentC; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseDirectConvSimt() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &d, + FragmentA a, + FragmentB b, + FragmentC const &c, int group_idx = 0) const { + + ThreadMma mma; + + mma(d, a, b, c); + } + + /// Transform 
the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..47fd1e08b9ff9f693b462fd89f5230475d918120 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -0,0 +1,862 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT + instructions +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/conv/convolution.h" + +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions +/// +/// concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + cutlass::gemm::Operand 
Operand, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1 +> +class DepthwiseMmaSimtTileIterator; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization for B operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseMmaSimtTileIterator + : public cutlass::gemm::warp::MmaSimtTileIterator { + + using Base = cutlass::gemm::warp::MmaSimtTileIterator; + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename Base::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Thread-level shape of a fragment + using ThreadShape = typename Base::ThreadShape; + + /// Number of individual loads + using 
Iterations = typename Base::Iterations; + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + + static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); + +private: + + MatrixCoord lane_offset_; + int channel_idx_; + int base_channel_idx_; + int warps_n_; + + public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator():Base() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator( + TensorRef ref, + int lane_id + ) : Base(ref, lane_id) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + warps_n_ = -1; + channel_idx_ = 0; + base_channel_idx_ = 0; + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + + if(warps_n_ == -1){ + warps_n_ = coord.column(); + } + + Base::add_tile_offset(coord); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < Iterations::kRow; ++k) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + + void const *ptr = this->ref_.data() + + this->ref_.offset({-(channel_idx_ - base_channel_idx_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + + // Base_k of a warp + Base_k of current threads. 
+ int thread_k_base_idx = + warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); + + if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { + // Depthwise kernel would only do computation when channel == k. + // Loads an element when the current computation channel == the k corresponding to this thread. + arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); + } else { + // Reduce SMEM load + dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + if(k_group % PartitionGroupSize == 0 && k_group != 0){ + base_channel_idx_ = k_group; + } + channel_idx_ = k_group; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadBlockOutputShape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, 
-1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1> +class DepthwiseDirect2dConvSimtTileIterator; + + +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm, + /// Stride ( MatrixShape ) + typename StrideShape, + /// Dilation ( MatrixShape ) + typename DilationShape, + /// Activation Shape loaded by threadblock + typename ActivationShape, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; 
+ + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + +// Thread-level shape of a fragment + using ThreadShape = MatrixShape< + ThreadOutputShape::kNHW, // Output tile shape Computed by current threads + ThreadOutputShape::kC + >; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = MatrixShape< + ThreadShape::kRow, + ThreadShape::kColumn / Policy::LaneMmaShape::kN + >; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + +protected: + + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + int iterator_offset_; + + int inc_next_s_ ; + int inc_next_r_ ; + + MatrixCoord 
lane_offset_; +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator( + TensorRef ref, + int lane_id + ) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + iterator_offset_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ template + CUTLASS_HOST_DEVICE + void setup_initial_status(Params const& params) { + + inc_next_s_ = params.inc_next[0]; + inc_next_r_ = params.inc_next[1]; + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_p_ = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + int base_q_ = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int base_w = (base_q_ + q) * params.stride[0]; + int base_h = (base_p_ + p) * params.stride[1]; + + int offset = base_h * params.activation_tile_w + base_w; + activation_offset[p][q][col] = offset; + } + } + } + } + + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + iterator_offset_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + iterator_offset_ += inc_next_s_; + + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + iterator_offset_ += inc_next_r_; + return *this; + } + + iterator_r_ = 0; + 
iterator_offset_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator & operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + Array *dst_ptr = + reinterpret_cast *>(&frag); + + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + void const *ptr = ref_.data() + + ref_.offset({activation_offset[p][q][n] + (iterator_offset_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = 
ActivationShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged " + "along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, + "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + + // Activations loaded by threadblock + static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ThreadActivationShape = cutlass::conv:: + TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; + + // Thread-level shape of a fragment + using ThreadShape = + MatrixShape; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = + MatrixShape; 
+ + using ThreadTileCount = MatrixShape; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + + protected: + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + Array + activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + + + MatrixCoord lane_offset_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
+ template + CUTLASS_HOST_DEVICE void setup_initial_status( + Params const ¶ms) { + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_h = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; + int base_w = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadActivationShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadActivationShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (base_h + h) * ActivationShape::kW + (base_w + w); + + void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); + arch::shared_load(activation[h][w][col], ptr); + } + } + } + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = + MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + return *this; + } + + iterator_r_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + 
DepthwiseDirect2dConvSimtTileIterator &operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; + const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; + + dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..6cb3935a7e070f0dc34b1ec9c31d9ac448d43b8b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/conv/warp/scale_bias_relu_transform.h @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level per channel scale+bias+relu before + matrix multiply-accumulate operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FpropScaleBiasReluTransform { + + using T = typename FragmentActivations::Element; + + static int const NumActivations = FragmentActivations::kElements; + static int const NumScaleBias = FragmentScaleBias::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 columns + static int const MmaCols = 2; + + using MmaOperand = Array; + using ScaleBiasOperand = Array; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t *ptr_activations = reinterpret_cast(&activations); + uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + + // We assumes the pair of FP16 are either both inbound or both out-of-bound. + // It requires C to be an even number. 
+ asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .b32 t1;\n\t" + " setp.eq.u32 %%p, %2, %4;\n\t" + " fma.rn.f16x2.relu t1, %1, %2, %3;\n" + " selp.u32 %0, 0, t1, %%p;\n\t" + "}\n" + : "=r"(ptr_activations[0]) + : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); +#else + assert(0); +#endif + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentScaleBias const &scale_bias) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + ScaleBiasOperand const *ptr_scale_bias = + reinterpret_cast(&scale_bias); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], ptr_scale_bias[(i / MmaScaleBiasPair) % MmaCols]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct WgradScaleBiasReluTransform { + + using T = typename FragmentActivations::Element; + + static int const NumActivations = FragmentActivations::kElements; + static int const NumScaleBias = FragmentScaleBias::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 rows + static int const MmaRows = 2; + + using MmaOperand = Array; + using ScaleBiasOperand = Array<__half2, MmaScaleBiasPair>; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, ScaleBiasOperand const &scale_bias) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + __half2 *ptr_activations = reinterpret_cast<__half2 *>(&activations); + uint32_t const *ptr_scale_bias = reinterpret_cast(&scale_bias); + +#if 1 + // CUDA + PTX version + + bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); + bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); + + // Apply per channel scale+bias+relu if the data is not a special NaN + // 
(0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + + // We cannot gurantee that the pair of F16 are both in bound or both + // out-of-bound because C x R x S can be an odd number. + asm volatile( + "{\n\t" + " fma.rn.f16x2.relu %0, %1, %2, %3;\n" + "}" + : "=r"(reinterpret_cast(ptr_activations[0])) + : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), + "r"(ptr_scale_bias[1])); + + reinterpret_cast(ptr_activations[0]) = h1_oob ? + (reinterpret_cast(ptr_activations[0]) & 0xffff0000) : + reinterpret_cast(ptr_activations[0]); + + reinterpret_cast(ptr_activations[0]) = h2_oob ? + (reinterpret_cast(ptr_activations[0]) & 0xffff) : + reinterpret_cast(ptr_activations[0]); +#else + // pure PTX version + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + asm volatile( + "{\n" + " .reg .b16 t1, t2;\n" + " .reg .b32 t3, t4, t5, t6;\n" + " .reg .pred p1, p2;\n" + " mov.b32 {t1, t2}, %2;\n" + " setp.eq.s16 p1, t1, %4;\n" + " setp.eq.s16 p2, t2, %4;\n" + " fma.rn.f16x2.relu t3, %1, %2, %3;\n" + " and.b32 t4, t3, %5;\n" + " selp.b32 t5, t4, t3, p1;\n" + " and.b32 t6, t5, %6;\n" + " selp.b32 %0, t6, t5, p2;\n" + "}\n" + : "=r"(reinterpret_cast(ptr_activations[0])) + : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); +#endif +#else + assert(0); +#endif + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentScaleBias const &scale_bias) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + ScaleBiasOperand const *ptr_scale_bias = + reinterpret_cast(&scale_bias); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], ptr_scale_bias[(i / MmaRows)]); + } + } +}; +} // namespace warp +} // namespace conv +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h new file mode 100644 index 0000000000000000000000000000000000000000..16cfa1b322f24f3e1c64f14b91dd880798e3b68d --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/coord.h @@ -0,0 +1,478 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix +*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#endif + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Statically-sized array specifying Coords within a tensor +template < + int Rank_, ///< Logical rank of coordinate + typename Index_ = int, ///< Index type used for each dimension + typename LongIndex_ = int64_t ///< Long index type used for linear offsets +> +struct Coord { + +public: + + // + // Type and constant definitions + // + + /// Number of elements in Coord + static int const kRank = Rank_; + + /// Index type used to store elements + using Index = Index_; + + /// Type used to represent linear offsets + using LongIndex = LongIndex_; + +private: + + // + // Data members + // + + /// Indices + Index idx[kRank]; + +public: + + // + // Methods + // + + /// Default ctor initializes uniformly + CUTLASS_HOST_DEVICE + explicit Coord(Index value = Index(0)) { + for (int i = 0; i < kRank; ++i) { + idx[i] = value; + } + } + + /// Constructs from an array of integers + CUTLASS_HOST_DEVICE + Coord(Index const (&_idx)[kRank]) { + for (int i = 0; i < kRank; ++i) { + idx[i] = _idx[i]; + } 
+ } + + /// Constructs from some other Coord + template + CUTLASS_HOST_DEVICE + Coord(Coord other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + + /// Returns a slice of the Coord which may be larger or smaller in rank + /// than this. + template + CUTLASS_HOST_DEVICE + Coord slice(int start = 0, Index identity = 0) const { + Coord result; + for (int i = 0; i < Slice; ++i) { + if (i + start < kRank) { + result[i] = idx[i + start]; + } + else { + result[i] = identity; + } + } + return result; + } + + /// Returns the index of the dimension with least value + CUTLASS_HOST_DEVICE + int min_dim_index() const { + int i = 0; + for (int j = 1; j < kRank; ++j) { + if (idx[j] < idx[i]) { + i = j; + } + } + return i; + } + + /// Returns the index of the dimension with greatest value + CUTLASS_HOST_DEVICE + int max_dim_index() const { + int i = 0; + for (int j = 1; j < kRank; ++j) { + if (idx[j] > idx[i]) { + i = j; + } + } + return i; + } + + /// Returns true if Coord is non-zero. + CUTLASS_HOST_DEVICE + explicit operator bool() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return true; + } + } + return false; + } + + /// Returns true if Coord is uniformly zero. 
+ CUTLASS_HOST_DEVICE + bool operator!() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return false; + } + } + return true; + } + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Coord operator+(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Coord operator-(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Coord operator*(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Coord operator/(Coord const& b) const { + Coord c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Coord& operator+=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Coord& operator-=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] -= b.idx[i]; + } + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Coord& operator*=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] *= b.idx[i]; + } + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Coord& operator/=(Coord const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] /= b.idx[i]; + } + return *this; + } + + /// Member access operator + CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } + + /// Member access operator + CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } + + /// Computes the dot product with anotherCoord object + CUTLASS_HOST_DEVICE + LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const { + for (int i = 0; i < kRank; 
++i) { + sum += idx[i] * b.idx[i]; + } + return sum; + } + + /// Gets the index of a given Coord element + template + CUTLASS_HOST_DEVICE Index& at() { + return idx[Dim]; + } + + /// Access via index; may limit unrolling potential + CUTLASS_HOST_DEVICE + Index& at(int dim) { return idx[dim]; } + + /// Gets the index of a given Coord element + template + CUTLASS_HOST_DEVICE Index const& at() const { + return idx[Dim]; + } + + /// Access via index; may limit unrolling potential + CUTLASS_HOST_DEVICE + Index const& at(int dim) const { return idx[dim]; } + + /// Determines if two Coord<> objects are equal + CUTLASS_HOST_DEVICE + bool operator==(Coord const& b) const { + bool equal = true; + for (int i = 0; equal && i < kRank; ++i) { + equal = (idx[i] == b.idx[i]); + } + return equal; + } + + /// Not equal + CUTLASS_HOST_DEVICE + bool operator!=(Coord const& b) const { return !(*this == b); } + + /// Clamps a coordinate to a range specified by maximum and minimum values + CUTLASS_HOST_DEVICE + Coord& clamp(Coord const& max, Coord const& min = Coord()) { + for (int i = 0; i < kRank; ++i) { + idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); + } + return *this; + } + + /// Returns the sum of all elements + CUTLASS_HOST_DEVICE + Index sum() const { + Index sum_(idx[0]); + for (int i = 1; i < kRank; ++i) { + sum_ += idx[i]; + } + return sum_; + } + + /// Returns the product of all elements + CUTLASS_HOST_DEVICE + LongIndex product() const { + LongIndex product_(idx[0]); + for (int i = 1; i < kRank; ++i) { + product_ *= idx[i]; + } + return product_; + } + + /// Less than operator + CUTLASS_HOST_DEVICE + bool operator<(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] < b[i])) { + return false; + } + } + return true; + } + + /// Less than or equals operator + CUTLASS_HOST_DEVICE + bool operator<=(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] <= b[i])) { + return false; + } + } + return true; + } + + /// 
Greater than operator + CUTLASS_HOST_DEVICE + bool operator>(Coord const &b) const { + return !(*this <= b); + } + + /// Greater than or equals operator + CUTLASS_HOST_DEVICE + bool operator>=(Coord const &b) const { + return !(*this < b); + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Index s, Coord coord) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + +/// Scalar multiplication +template +CUTLASS_HOST_DEVICE +Coord operator*(Coord coord, Index s) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] *= s; + } + return coord; +} + +/// Scalar division +template +CUTLASS_HOST_DEVICE +Coord operator/(Index s, Coord coord) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = s / coord[i]; + } + return coord; +} + +/// Scalar division +template +CUTLASS_HOST_DEVICE +Coord operator/(Coord coord, Index s) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] /= s; + } + return coord; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Integer-valued make_Coord +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to make a 1-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<1, T> make_Coord(T _0) { + T values[1] = {_0}; + return Coord<1, T>(values); +} + +/// Helper to make a 2-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<2, T> make_Coord(T _0, T _1) { + T values[2] = {_0, _1}; + return Coord<2, T>(values); +} + +/// Helper to make a 3-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<3, T> make_Coord(T _0, T _1, T _2) { + T values[3] = {_0, _1, _2}; + return Coord<3, T>(values); +} + +/// Helper to make a 4-element 
coordinate +template +CUTLASS_HOST_DEVICE +Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) { + T values[4] = {_0, _1, _2, _3}; + return Coord<4, T>(values); +} + +/// Helper to make a 5-element coordinate +template +CUTLASS_HOST_DEVICE +Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) { + T values[5] = {_0, _1, _2, _3, _4}; + return Coord<5, T>(values); +} + +/// Helper to make a 1-element coordinate +template +CUTLASS_HOST_DEVICE +Coordmake_Coord_with_padding(T _0) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = N - 1; i > 0; --i) { + coord[i] = 0; + } + + coord[0] = _0; + + return coord; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h new file mode 100644 index 0000000000000000000000000000000000000000..046b3063a8ca7e1b79248ddea8d10af239eb4bdb --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/core_io.h @@ -0,0 +1,328 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Helpers for printing cutlass/core objects +*/ +#pragma once + +#include +#include + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix.h" +#include "cutlass/quaternion.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm_enumerated_types.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Output operator for CUDA built-in dim3 type +inline std::ostream &operator<<(std::ostream &out, dim3 d) { + return out << d.x << ", " << d.y << ", " << d.z; +} + +/// Output operator for CUDA built-in error type +inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { + return out << cudaGetErrorString(error); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline +std::ostream& operator<<(std::ostream& out, Array const& v) { + for (int i = 0; i < Rank; ++i) { + out << (i ? ", " : "") << v[i]; + } + return out; +} + +template +inline +std::ostream& operator<<(std::ostream& out, Coord const& coord) { + for (int i = 0; i < Rank; ++i) { + out << (i ? 
", " : "") << coord[i]; + } + return out; +} + +inline +std::istream & operator>>(std::istream &stream, half_t &x) { + float tmp; + stream >> tmp; + x = static_cast(tmp); + return stream; +} + +inline +std::ostream & operator<<(std::ostream &out, half_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, bfloat16_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, tfloat32_t const &x) { + return out << float(x); +} + + +inline +std::ostream & operator<<(std::ostream &out, float_e2m1_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e2m1_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_e3m2_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_e2m3_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e3m2_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, detail::float_e2m3_unpacksmem_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_ue8m0_t const &x) { + return out << float(x); +} + +inline +std::ostream & operator<<(std::ostream &out, float_ue4m3_t const &x) { + return out << float(x); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to enable formatted printing of CUTLASS scalar types to an ostream +template +struct ScalarIO { + + /// Value to print + T value; + + /// Default ctor + ScalarIO() { } + + /// Constructs from a value + ScalarIO(T value): value(value) {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default printing to ostream +template +inline std::ostream 
&operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << scalar.value; +} + +/// Printing to ostream of int8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << int(scalar.value); +} + +/// Printing to ostream of uint8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << unsigned(scalar.value); +} + + +/// Default printing to ostream for MatrixShape +template +inline +std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { + out << "cutlass::MatrixShape::(kRow, kColumn) {" + << cutlass::MatrixShape::kRow <<"," + << cutlass::MatrixShape::kColumn <<"}"; + return out; +} + + +/// Prints matrix to ostream +template +std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { + + for (int i = 0; i < Rows; ++i) { + for (int j = 0; j < Columns; ++j) { + ScalarIO element(rhs.at(i, j)); + out << (j ? 
", " : "") << element; + } + out << "\\n"; + } + + return out; +} + +template +std::ostream &operator<<(std::ostream &out, Quaternion const &rhs) { + + out << ScalarIO(rhs.w()) << " "; + if (rhs.x() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.x()) << "*i "; + if (rhs.y() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.y()) << "*j "; + if (rhs.z() >= 0) { + out << "+"; + } + + out << ScalarIO(rhs.z()) << "*k"; + + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::gemm namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace gemm { + +/// Default printing to ostream for GemmShape +template +inline +std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { + out << "cutlass::gemm::GemmShape::(kM, kN, kK) {" + << cutlass::gemm::GemmShape::kM <<"," + << cutlass::gemm::GemmShape::kN <<"," + << cutlass::gemm::GemmShape::kK << "}"; + return out; +} + +/// Default printing to ostream for GemmCoord +inline +std::ostream & operator<<(std::ostream &out, GemmCoord const &gemm_coord) { + out << "cutlass::gemm::GemmCoord {" + << gemm_coord.m() <<"," + << gemm_coord.n() <<"," + << gemm_coord.k() << "}"; + return out; +} + +} //namespace gemm +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default printing to ostream for PitchLinearShape +template < int Contiguous, int Strided> +inline +std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { + out << "cutlass::PitchLinearShape:(kContiguous, kStrided) {" + << 
cutlass::layout::PitchLinearShape::kContiguous <<"," + << cutlass::layout::PitchLinearShape::kStrided <<"}"; + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// stream operators for cutlass::conv namespace // +/////////////////////////////////////////////////////////////////////////////////////////////////// +namespace conv { +/// Default printing to ostream for Conv2dProblemSize +inline +std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { + out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl + << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl + << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "groups: (" << problem.groups << ")" << std::endl + << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl + << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl + << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl + << "split_k_slices: (" << problem.split_k_slices << ")" << std::endl + << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? 
"conv" : "xcross") << ")"; + + return out; +} + + +/// Default printing to ostream for Conv3dProblemSize +inline +std::ostream& operator<<(std::ostream& out, Conv3dProblemSize const& problem) { + out << "NDHWC: (" << problem.N << ", " << problem.D << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl + << "KTRSC: (" << problem.K << ", " << problem.T << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl + << "NZPQK: (" << problem.N << ", " << problem.Z << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "pad_d, pad_h, pad_w: (" << problem.pad_d << ", " << problem.pad_h << ", " << problem.pad_w << ")" << std::endl + << "stride_d, stride_h, stride_w: (" << problem.stride_d << ", " << problem.stride_h << ", " << problem.stride_w << ")" << std::endl + << "dilation_d, dilation_h, dilation_w: (" << problem.dilation_d << ", " << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl + << "split_k_slices: (" << problem.split_k_slices << ") " << std::endl + << "mode: (" << ((problem.mode==conv::Mode::kConvolution) ? 
"conv" : "xcross") << ")"; + + return out; +} + +} // namespace conv +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a8af62be2d3e27ccf499acaead03dc3aadd4c151 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cuda_host_adapter.hpp @@ -0,0 +1,428 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Interface between a CUTLASS device-wide operator and CUDA. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" + +#include "cutlass/platform/platform.h" +#if ! defined(__CUDACC_RTC__) +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// NVRTC doesn't need definitions for these host classes + +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) \ + && !defined(__CUDACC_RTC__) +#define CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)) +#define CUDA_HOST_ADAPTER_TENSORMAP_ENABLED +#endif + +// Include for CUDA Driver API calls if any of these capabilities are enabled. 
+#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || \ + defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +#include + +#endif // defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || + // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Macro-level guard for CUDA Host Adapter +// +#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) +#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +#if !defined(__CUDACC_RTC__) + +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +#include +#endif // (__CUDACC_VERSION__ >= 11.8) + +#include + +#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok + +#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + return func(args...); \ + } + +#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if (__CUDACC_VER_MAJOR__ > 12) + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, ver, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#else + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... 
args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPoint( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#endif // (__CUDACC_VER_MAJOR__ > 12) + +#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if (__CUDACC_VER_MAJOR__ >= 12) +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000); +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); +#endif + +#undef CUTLASS_CUDA_DRIVER_STRINGIFY + +#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func + +#endif // !defined(__CUDACC_RTC__) + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter +/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute +/// is not introduced. 
+struct CudaHostLaunchAttributes { + +#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + + /// Reasonable maximum launch attributes that are commonly applied + static constexpr int32_t kMaximumAttributeCount = 5; + + /// Launch attributes + CUlaunchAttribute launch_attributes[kMaximumAttributeCount]; + int32_t attribute_count = 0; + + CUTLASS_HOST_DEVICE + CudaHostLaunchAttributes(CUlaunchAttribute *launch_attributes_ = nullptr, + int32_t attribute_count_ = 0) { + CUTLASS_ASSERT(attribute_count_ >= 0 && attribute_count_ < kMaximumAttributeCount); + for (int32_t i = 0; i < attribute_count_ && i < kMaximumAttributeCount; ++i) { + launch_attributes[i] = launch_attributes_[i]; + } + attribute_count = attribute_count_; + } + + CUTLASS_HOST_DEVICE + CUlaunchAttribute const* data() const { + return launch_attributes; + } + + CUTLASS_HOST_DEVICE + size_t size() const { + return attribute_count; + } + +#endif // (CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + +}; + + +/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and +/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. 
+struct CudaHostAdapter { + + /// Limit the number of kernels + static constexpr int32_t kMaximumKernelCount = 4; + + /// Maximum cluster size + static constexpr int MaxClusterSize = 32; + + // + // Data members + // + + /// Handles + void *kernel_handles[kMaximumKernelCount]; + int32_t kernel_count = 0; + + CudaHostLaunchAttributes launch_attributes; + + // + // Methods + // + + /// Ctor + CudaHostAdapter() = default; + + /// Dtor + virtual ~CudaHostAdapter() = default; + + /// Copy Ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(const CudaHostAdapter & rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(rhs.launch_attributes) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + /// Copy Assignment + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + + launch_attributes = rhs.launch_attributes; + + return *this; + } + + + /// Move ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(CudaHostAdapter && rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(std::move(rhs.launch_attributes)) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + // / Move assignment + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(CudaHostAdapter && rhs) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count 
= rhs.kernel_count; + launch_attributes = std::move(rhs.launch_attributes); + return *this; + } + + /// Ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(void **kernel_handles_, + int32_t kernel_count_, + CudaHostLaunchAttributes const &launch_attributes_ = { }) + : kernel_count(kernel_count_), + launch_attributes(launch_attributes_) { + CUTLASS_ASSERT(kernel_count >= 0 && kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = kernel_handles_[i]; + } + } + + /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) + CUTLASS_HOST_DEVICE + bool empty() const { return !kernel_count; } + + /// Returns kernel_count + CUTLASS_HOST_DEVICE + size_t size() const { return static_cast(kernel_count); } + + /// Queries the occupancy of a kernel + virtual Status query_occupancy( + int32_t *device_sms, + int32_t *sm_occupancy, + int32_t kernel_index, + int32_t thread_count, + int32_t smem_size) const = 0; + + /// Launches a kernel without using Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. + /// This API is for preferred cluster launch; a preferred and a fallback cluster shapes are + /// considered for launch respectively. 
+ virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const fallback_cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + + +#if defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + + /// Create a tensor map descriptor object representing im2col memory region. + virtual CUresult tensorMapEncodeIm2col ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const int* pixelBoxLowerCorner, + const int* pixelBoxUpperCorner, + cuuint32_t channelsPerPixel, + cuuint32_t pixelsPerColumn, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Create a tensor map descriptor object representing tiled memory region. + virtual CUresult tensorMapEncodeTiled ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Modify an existing tensor map descriptor with an updated global address. + virtual CUresult tensorMapReplaceAddress( + CUtensorMap* tensorMap, + void* globalAddress) const = 0; + +#endif // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +protected: + + /** + * Fills a buffer in Global Memory with a byte sequence copied from host memory. 
+ * This function can be overridden to dispatch to the appropriate cuMemsetD*Async API + */ + virtual Status memsetDeviceImpl( + void* destination, ///< Device memory pointer to be filled + void const* fill_value, ///< Value to be filled in the buffer + size_t fill_size, ///< Size of the data type to be used for filling the buffer + size_t count, ///< Number of elements of size fill_size + cudaStream_t stream) const = 0; + +public: + + /// Fills a buffer in Global Memory with a byte sequence copied from host memory + template + CUTLASS_HOST_DEVICE + Status memsetDevice( + void* destination, + FillValueType fill_value, + size_t count, + cudaStream_t stream) const { + return this->memsetDeviceImpl( + destination, + &fill_value, + sizeof(FillValueType), + count, + stream); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h new file mode 100644 index 0000000000000000000000000000000000000000..c68a3ba38cb554278e692d012ca2a93b547e08f1 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/cutlass.h @@ -0,0 +1,165 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic include for CUTLASS. +*/ + +#pragma once + +#include "cutlass/detail/helper_macros.hpp" + +#if (__CUDACC_VER_MAJOR__ >= 13) + #define CUDA_STD_HEADER(header) +#else + #define CUDA_STD_HEADER(header) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/// Status code returned by CUTLASS operations +enum class Status { + kSuccess, ///< Operation was successful. + kErrorMisalignedOperand, ///< operands fail alignment requirements. + kErrorInvalidDataType, ///< DataType fails requirement. + kErrorInvalidLayout, ///< Layout fails alignment requirement. 
+ kErrorInvalidProblem, ///< Specified problem size is not supported by operator. + kErrorNotSupported, ///< Operation is not supported on current device. + kErrorWorkspaceNull, ///< The given workspace is null when it is required to be non-null. + kErrorInternal, ///< An error within CUTLASS occurred. + kErrorArchMismatch, ///< CUTLASS runs on a device that it was not compiled for. + kErrorInsufficientDriver, ///< CUTLASS runs with a driver that is too old. + kErrorMemoryAllocation, ///< Kernel launch failed due to insufficient device memory. + kInvalid ///< Status is unspecified. +}; + +/// Convert cutlass status to status strings +CUTLASS_HOST_DEVICE +static char const* cutlassGetStatusString(cutlass::Status status) { + switch (status) { + case cutlass::Status::kSuccess: + return "Success"; + case cutlass::Status::kErrorMisalignedOperand: + return "Error Misaligned Operand"; + case cutlass::Status::kErrorInvalidDataType: + return "Error Invalid Data Type"; + case cutlass::Status::kErrorInvalidLayout: + return "Error Invalid Layout"; + case cutlass::Status::kErrorInvalidProblem: + return "Error Invalid Problem"; + case cutlass::Status::kErrorNotSupported: + return "Error Not Supported"; + case cutlass::Status::kErrorWorkspaceNull: + return "Error Workspace Null"; + case cutlass::Status::kErrorInternal: + return "Error Internal"; + case cutlass::Status::kErrorInsufficientDriver: + return "Error Insufficient Driver"; + case cutlass::Status::kErrorArchMismatch: + return "Error Architecture Mismatch"; + case cutlass::Status::kErrorMemoryAllocation: + return "Error Memory Allocation failed"; + case cutlass::Status::kInvalid: break; + } + + return "Invalid status"; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static const int 
NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper function to return true when called by thread 0 of threadblock 0. +CUTLASS_HOST_DEVICE bool thread0() { + #if defined(__CUDA_ARCH__) + return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); + #else + return false; + #endif +} + +/// Returns a lane index in the warp. The threads in warp may not be convergent +CUTLASS_DEVICE +int canonical_lane_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x % NumThreadsPerWarp; + #else + return 0; + #endif +} + +/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_idx_sync() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); + #else + return 0; + #endif +} + +/// Returns a warp index in the CTA. The threads in warp may not be convergent +/// As it doesn't sync the warp, it faster and allows forward progress +CUTLASS_DEVICE +int canonical_warp_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x / NumThreadsPerWarp; + #else + return 0; + #endif +} + +/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. +/// Threads within the warp must be converged. 
+CUTLASS_DEVICE +int canonical_warp_group_idx() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); + #else + return 0; + #endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a304cd6e3adae2c009c5b474a8e2920b618a3ea3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/blockwise_scale_layout.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Blockwise Scale configs specific for Blockwise/Groupwise MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" +#include "cute/arch/mma_sm90.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm1xxBlockwiseScaleConfig { + + using ShapeSFA = Shape, int32_t>, Shape, int32_t>, int32_t>; + using ShapeSFB = Shape, int32_t>, Shape, int32_t>, int32_t>; + + using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutSFA = Layout; + using LayoutSFB = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSFA{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layoutSFA(CtaShape_MNK cta_shape_mnk) { + static_assert(cute::is_static_v, "Expect 
static CTA shape"); + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K] = cta_shape_mnk; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeM)>{})); + } + else { + return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K] = cta_shape_mnk; + return make_layout( + make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeM)>{}), + make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), + strides + ); + } + + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSFB{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layoutSFB(CtaShape_MNK cta_shape_mnk) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeN)>{})); + } + else { + return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K] = cta_shape_mnk; + return make_layout( + make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeN)>{}), + make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), + strides + ); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, SFVecSizeM))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto mk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(M, SFVecSizeM)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (majorSFB == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, SFVecSizeN))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto nk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(N, SFVecSizeN)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); + } + +}; + +template +struct RuntimeBlockwiseScaleConfig { + + using ShapeSFA = Shape, Shape, int32_t>; + using ShapeSFB = Shape, Shape, int32_t>; + + using StrideSFA = conditional_t,Stride<_0,int32_t>, 
int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutSFA = Layout; + using LayoutSFB = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSFA{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSFB{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, sfm))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfm, cute::ceil_div(M, sfm)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + + if constexpr (majorSFB == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, sfn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto [sfm, sfn, sfk] = sf_vec_shape; + auto nk_layout = make_layout( + make_shape(make_shape(sfn, cute::ceil_div(N, sfn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); + } + +}; + +// Sm90 only supports MN major for SFA and SFB for now +template +using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig< + SFVecSizeM, + SFVecSizeN, + SFVecSizeK, + majorSFA == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K, + majorSFB == cute::GMMA::Major::MN ? 
UMMA::Major::MN : UMMA::Major::K>; + +template +using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; + +template +using Sm120BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; + +template +constexpr auto sm90_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm90BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + +template +constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm100BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + +template +constexpr auto sm120_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm120BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d35765adebaa35bfcd767ff245ec72d453c28563 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/cluster.hpp @@ -0,0 +1,99 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + + + +#include "cute/container/tuple.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/trace.h" +#include "cute/layout.hpp" // cute::make_shape +#include "cutlass/trace.h" // CUTLASS_TRACE_HOST + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +// Returns either ClusterShape, if it is static, or a Shape> populated with the +// x and y dimensions of `dynamic_cluster_shape`. 
+template +CUTLASS_HOST_DEVICE +static auto +select_cluster_shape(ClusterShape cluster_shape, dim3 dynamic_cluster_shape) { + return cute::conditional_return>( + make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}), + cluster_shape); +} + +template +CUTLASS_DEVICE +static auto +select_cluster_shape(ClusterShape cluster_shape) { + if constexpr (cute::is_static_v) { + return cluster_shape; + } + else { + dim3 dynamic_cluster_shape = cute::cluster_shape(); + return make_shape(static_cast(dynamic_cluster_shape.x), static_cast(dynamic_cluster_shape.y), cute::Int<1>{}); + } +} + +// Dynamic cluster shape can_implement rule +template +CUTLASS_HOST_DEVICE +bool +preferred_cluster_can_implement(dim3 cluster_shape, dim3 cluster_shape_fallback) { + bool implementable{true}; + + // Runtime cluster shape should satisfy MMA requirements + auto AtomThrShapeM = cute::size<0>(AtomThrShapeMNK{}); + implementable &= (cluster_shape.x > 0 && cluster_shape.y > 0 && cluster_shape.z > 0); + implementable &= (cluster_shape.x % AtomThrShapeM == 0); + + implementable &= (cluster_shape_fallback.x > 0 && cluster_shape_fallback.y > 0 && cluster_shape_fallback.z > 0); + implementable &= (cluster_shape_fallback.x % AtomThrShapeM == 0); + + // Only support pow2 runtime cluster shape for now + implementable &= ispow2(cluster_shape.x) && + ispow2(cluster_shape.y) && + ispow2(cluster_shape.z); + + implementable &= ispow2(cluster_shape_fallback.x) && + ispow2(cluster_shape_fallback.y) && + ispow2(cluster_shape_fallback.z); + + return implementable; +} + +} // namespace cutlass::detail + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..01085c54159fc1cd5d6b7e2ee1d40a46cccd4f67 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective.hpp @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/container/tuple.hpp" +#include "cute/layout.hpp" // cute::size(shape) +#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::MXF4Format, cute::UMMA::MXF8F6F4Format +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct deduce_mixed_width_dtype { +static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively."); + +private: + using underlying_tuple = cute::conditional_t::value, Tuple, cute::tuple>; + static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v - 1); + +public: + using type = cute::conditional_t<(I < cute::tuple_size_v), + cute::tuple_element_t, + void>; +}; + +template +using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; + + + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_runtime_f8f6f4() { + return (cute::is_same_v || + cute::is_same_v || + cute::is_same_v); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_f8f6f4_inputs() { + return ( + + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + + cute::is_same_v || + cute::is_same_v + || cute::is_same_v || + cute::is_same_v || + cute::is_same_v + + ) && + ( + + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + + cute::is_same_v || + cute::is_same_v + || cute::is_same_v || + cute::is_same_v || + cute::is_same_v + + ); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm100_mma_f8f6f4() { + return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 32) && is_sm10x_f8f6f4_inputs(); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_f8f6f4_element() { + return 
(cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + + ); +} + + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_f4_element() { + return (cute::is_same_v + ); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_mxf8f6f4_input() { + // ElementType must be F8, F6, or F4 + return ( cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm10x_mxf4nvf4_input() { + // ElementType must be F4 + return ( cute::is_same_v || + cute::is_same_v + ); +} + +template +struct sm10x_block_scale_runtime_input_t { + static constexpr bool IsF8F6F4MmaInput = is_sm10x_mxf8f6f4_input(); + static constexpr bool IsF4MmaInput = is_sm10x_mxf4nvf4_input(); + + using Type = cute::conditional_t + >; +}; + + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm120_f8f6f4() { + return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 32) && is_sm10x_f8f6f4_inputs(); +} + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm100_sparse_f8f6f4() { + return (cute::size<2>(typename TiledMma::Shape_MNK{}) == 64) && is_sm10x_f8f6f4_inputs(); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89d250001eaa00990319cc2a1da35ec0dccb8703 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -0,0 +1,1249 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +// The universal converter +template < + class SrcType, + class DstType, + class LayoutIn, + class LayoutOut +> +struct LayoutAwareConvertImpl { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor const& src, + cute::Tensor & dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + static_assert(cute::cosize_v == cute::cosize_v); + constexpr int N = decltype(cute::max_common_vector(LayoutIn{}, LayoutOut{})){}; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using Converter = cutlass::NumericArrayConverter; + auto&& src_vm = cute::recast(src); + auto&& dst_vm = cute::recast(dst); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < src_vm.size(); ++i) { + dst_vm(i) = Converter::convert(src_vm(i)); + } + } +}; + +// Specialization for INT4 -> BF16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int4b_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (4 * (ii)); + static constexpr uint32_t xor_mask = 0x43084308; + static constexpr uint32_t lo_mask = 
0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(xor_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = xor_mask; // 0x43084308, {136, 136} + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, + reinterpret_cast(lo_bias)); + } + } + } +}; + +// Specialization for UINT4 -> BF16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::uint4b_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (4 * (ii)); + static constexpr uint32_t or_mask = 0x43004300; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(or_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = or_mask; // 0x43004300, {128, 128} + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, + reinterpret_cast(lo_bias)); + } + } + } +}; + +// Specialization for INT4 -> FP16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int4b_t, + cutlass::half_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& 
dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src_reg >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + static constexpr uint32_t lo_xor_mask = 0x64086408; + static constexpr uint32_t hi_xor_mask = 0x64806480; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t hi_mask = 0x00F000F0; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(lo_mask), "n"(lo_xor_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(hi_mask), "n"(hi_xor_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t hi_bias = 0xD480D480; // {-72, -72} + static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast(lo_bias)); + } + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hi_scale), + reinterpret_cast(hi_bias)); + } + } + } +}; + +// Specialization for UINT4 -> FP16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::uint4b_t, + cutlass::half_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = 
cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src_reg >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + static constexpr uint32_t or_mask = 0x64006400; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t hi_mask = 0x00F000F0; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(or_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(hi_mask), "n"(or_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = or_mask; // 0x64006400, {1024, 1024} + static constexpr uint32_t hi_bias = 0xD400D400; // {-64, -64} + static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast(lo_bias)); + } + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hi_scale), + reinterpret_cast(hi_bias)); + } + } + } +}; +/* +// Specialization for E5M2 -> FP16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::float_e5m2_t, + cutlass::half_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + // in registers: a3, a1, a2, a0 + r[RegArray::kElements - ii - 1] = src_reg << (8 * (ii)); + + static constexpr uint32_t 
and_mask = 0xFF00FF00; + asm volatile( + "{\n" + " and.b32 %0, %0, %1;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask)); + } + } +}; +*/ +// Specialization for INT8 -> BF16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int8_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + uint32_t tmp0, tmp1; + r[ii] = src_reg >> (8 * (ii)); + static constexpr uint32_t or_mask = 0x43004300; + static constexpr uint32_t and_mask_0 = 0x007F007F; + static constexpr uint32_t and_mask_1 = 0x00800080; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(tmp0) + : "r"(r[ii]), "n"(and_mask_0), "n"(or_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(tmp1) + : "r"(r[ii]), "n"(and_mask_1), "n"(or_mask), "n"(immLut)); + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(reinterpret_cast<__nv_bfloat162 const&>(tmp0), + reinterpret_cast<__nv_bfloat162 const&>(tmp1)); + } + } + } +}; + +// Specialization for INT8 -> FP16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int8_t, + cutlass::half_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; 
+ using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (8 * (ii)); + static constexpr uint32_t xor_mask = 0x64806480; + static constexpr uint32_t and_mask = 0x00FF00FF; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + { + static constexpr uint32_t bias = 0x64806480; + __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast<__half2 const&>(bias)); + } + } + } +}; + +template < + class EngineIn, + class EngineOut, + class LayoutIn, + class LayoutOut +> +CUTLASS_DEVICE +void LayoutAwareConvert( // Accept mutable temporaries + cute::Tensor const& src, + cute::Tensor && dst) { + + LayoutAwareConvert(src, dst); +} +template < + class EngineIn, + class EngineOut, + class LayoutIn, + class LayoutOut +> +CUTLASS_DEVICE +void LayoutAwareConvert( + cute::Tensor const& src, + cute::Tensor & dst) { + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + Tensor src_vm = coalesce(src); + Tensor dst_vm = coalesce(dst); + Layout src_layout = src_vm.layout(); + Layout dst_layout = dst_vm.layout(); + LayoutAwareConvertImpl::convert(src_vm, dst_vm); +} + + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + namespace detail { + enum class ConversionMode { + DirectConvert, // A * B + ConvertAndScale, // (scale * A) * B + ConvertAndScaleWithZero // (scale * A + zeros) * B + }; + } // namespace detail +} //namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace 
cutlass::gemm::collective::detail { + +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_logical_ptr(PointerType const* ptr) { + return cute::recast_ptr(ptr); +} +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return tile_to_shape( + layout_atom, + append(tile_shape, Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}); + } + else { + auto gmem_tile = composition(stride, tile_shape); + return make_layout_like(append(gmem_tile, make_layout(Int{}, 0))); + } +} +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_gmem_layout(Shape const& shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return make_layout(shape, stride); + } + else { + return stride; + } +} + +template +struct MixedInputUtils { +private: + using ConversionMode = cutlass::detail::ConversionMode; + using KernelSchedule = typename Collective::KernelSchedule; + using SmemLayoutA = typename Collective::SmemLayoutA; + using SmemLayoutB = typename Collective::SmemLayoutB; + using SmemLayoutScale = typename Collective::SmemLayoutScale; + using SwappedElementA = typename Collective::SwappedElementA; + using SwappedElementB = typename Collective::SwappedElementB; + using RealSwappedElementA = typename Collective::RealSwappedElementA; + using RealSwappedElementB = typename Collective::RealSwappedElementB; + using ElementScale = typename Collective::ElementScale; + using ElementZero = typename Collective::ElementZero; + using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale; + static constexpr auto KernelConversionMode = Collective::KernelConversionMode; + static constexpr auto ModeHasScales = Collective::ModeHasScales; + static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + +public: + static constexpr auto + 
elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto + elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale ) { + return 0; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_extra() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = 
cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t + compute_tma_transaction_bytes_extra_transform() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE + static void copy_tensors_MK( + SmemTiledCopyA const& smem_tiled_copy_A, + 
TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, + int k_block, + int read_stage) { + + copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block)); + + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_,_,k_block,read_stage), tCrS_copy_view(_,_,k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_,_,k_block,read_stage), tCrZ_copy_view(_,_,k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + /// (Designed for separate transform pipeline in Blackwell) + /// Utilities to copy extra inputs from smem to RF + template + CUTLASS_DEVICE + static void copy_scale_zeros_for_transform( + cute::tuple & partitioned_transform_extra_info, + int load2transform_consumer_index) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); + auto&& scales = cute::get<1>(partitioned_transform_extra_info); + using ScaleType = decltype(scales); + auto tSrS = 
make_tensor(scales.data(), scales.layout()); + auto tSsS = cute::get<2>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto&& zeros = cute::get<3>(partitioned_transform_extra_info); + using ZeroType = decltype(zeros); + auto tZrZ = make_tensor(zeros.data(), zeros.layout()); + auto tZsZ = cute::get<4>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); + + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + + // The core converter uses a lookup table to converts i4 -> 8 bit value. 
+ template + CUTLASS_DEVICE + static void lookup_table_convert( // Accept mutable temporaries + Tensor const& src, + Tensor && dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + + lookup_table_convert(src, dst, scales_neg, scales_pos); + } + template + CUTLASS_DEVICE + static void lookup_table_convert( + Tensor const& src, + Tensor & dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + + constexpr int N = cute::cosize(LayoutIn{}); + static_assert(N == 4 || N == 8); + static_assert(cosize(LayoutScale{}) <= N / 4, + "at least 4 consecutive weights must share the same scale."); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + // View the input as reg + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + + // Determines if to get from the signed or unsigned candidates + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" \ + "}\n" + : "=r"(sign) + : "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut) + ); + sign = sign >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = src_reg & 0x77777777; + Tensor scales_neg_ = cute::filter(scales_neg); + Tensor scales_pos_ = cute::filter(scales_pos); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i, lut_idx >>=16, sign >>=16) { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_(i)); + auto&& scale_pos_ = reinterpret_cast const&>(scales_pos_(i)); + asm volatile( + "{\n" + " .reg .b32 pos, neg ;\n" \ + " prmt .b32 neg, %3, %4, %1 ;\n" \ + " prmt .b32 pos, %5, %6, %1 ;\n" \ + " prmt .b32 %0, pos, neg, %2 ;\n" \ + "}\n" + : "=r"(r[i]) + : "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]) + ); + } + } + + /// Utilities to dequantize A. 
+ template + CUTLASS_DEVICE + static void static_check_scale(Layout const& tensor) { + static_assert(shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, "At least 4 adjacent weights in a thread must share the same scale."); + } + template + CUTLASS_DEVICE + static void static_check_scale(Tensor const& tensor) { + static_check_scale(flatten(Layout{})); + } + template + CUTLASS_DEVICE + static void dequantize_A_kblock( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + cute::tuple& partitioned_extra_info, + int const k_block) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + Tensor src = tCrA_load(_, _, k_block); + Tensor dst = tCrA_mma(_, _, k_block); + + CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()), + "The first mode of tensor src must be contiguous in memory"); + // try to make the size of the first mode equal to 32bit + int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, + ceil_div(32, sizeof_bits_v)); + Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, Int{})); + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, Int{})); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } + } + else if constexpr (UseScaleLookupTable) { + constexpr int num_elements = decltype(size(src))::value; + static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); + static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); + static_assert(num_elements % 4 == 0 && 
num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); + + Tensor tCrS_neg = cute::get<1>(partitioned_extra_info); + auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed + Tensor scales_neg = tCrS_neg(_, _, k_block); + Tensor scales_pos = tCrS_pos(_, _, k_block); + CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg)); + + static_check_scale(scales_neg); + static_check_scale(scales_pos); + Tensor scales_neg_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_neg, Int{})); + Tensor scales_pos_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_pos, Int{})); + + if (k_block == 0) { + Tensor scales_neg_vm_ = filter(scales_neg_vm); + Tensor scales_pos_vm_ = filter(scales_pos_vm); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) + { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); + auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ + "}\n" + : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i)); + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) *= scales_vm(j, 
i); + } + } + } + else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) *= scales_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); + Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i); + } + } + } + else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + + /// (Designed for separate transform pipeline in Blackwell) + /// Utilities to dequantize A. 
+ template + CUTLASS_DEVICE + static void dequantize_A_kblock_for_transform( + Tensor const& tArA, + Tensor& tArACompute, + cute::tuple const& partitioned_extra_info, + int const k_block) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto src = tArA(_, _, _, k_block); + auto dst = tArACompute(_, _, _, k_block); + constexpr int num_elements = decltype(size(src))::value; + + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int DstElementsPerReg = 32 / sizeof_bits_v; + using RegArray = cutlass::AlignedArray; + + auto src_arr = recast(src); + auto dst_arr = recast(dst); + + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); + + cute::transform(src_arr, dst_arr, Converter::convert); + + if constexpr (ModeHasScales) { + + auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block); + + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) { + + using ScaleArray = cutlass::Array; + auto scale_arr = recast(filter_zeros(scales)); + + if constexpr (is_same_v){ + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& scale_reg = cute::recast(scales_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hmul2(bf16x2_val, + reinterpret_cast(scale_reg[ii])); + } + } + } + else{ + 
cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{}); + } + } + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Do Nothing + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + + auto const& zeros = cute::get<3>(partitioned_extra_info)(_,_,_,k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) { + using ZeroArray = cutlass::Array; + auto zero_arr = recast(filter_zeros(zeros)); + + if constexpr (is_same_v) { + Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& zero_reg = cute::recast(zeros_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hadd2(bf16x2_val, + reinterpret_cast(zero_reg[ii])); + } + } + } + else{ + cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{}); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } +} + + + /// Utilities for any additional inputs inside of the TMA load + template < + class Params, + class TensorStorage, + class... 
Ts + > + CUTLASS_DEVICE + static auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. 
+ template < + class ThreadMma, + class TensorStorage + > + CUTLASS_DEVICE + static auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (UseScaleLookupTable) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); + } + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + template < + class TiledMma, + class TiledCopy, + class 
TensorStorage + > + CUTLASS_DEVICE + static auto partition_extra_transform_info( + TiledMma const& tiled_mma, + TiledCopy const& smem_tiled_copy_S, + TensorStorage& shared_storage) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(threadIdx.x % 128); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = cta_mma.partition_A(sS); + Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); + Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = cta_mma.partition_A(sZ); + Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); + Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. 
+ template + CUTLASS_DEVICE + static auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } +}; + +} // cutlass::gemm::collective::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp new file mode 100644 index 0000000000000000000000000000000000000000..04120a41ae0f404ed22ed05d08f138526b8e9fc3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/collective/sm103_kernel_type.hpp @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Kernel type definitions specific for SM103 BlockScaled MMA +*/ + +#pragma once + +namespace cutlass::sm103::detail { + +enum class KernelPrefetchType { + TmaPrefetch, // TMA Prefetch (is the default version) + Disable // Disable Prefetch +}; + +} // namespace cutlass::sm103::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d2dd6a16a67c12beece2645bf4781b820e07e78e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/dependent_false.hpp @@ -0,0 +1,86 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +/// @brief A bool constant that depends on one or more template parameters. +/// +/// For more detailed documentation and use cases, +/// please see `dependent_false` below. +template +inline constexpr bool dependent_bool_value = Value; + +/// @brief An always-false value that depends on one or more template parameters. +/// +/// This exists because `static_assert(false);` always fails, +/// even if it occurs in the `else` branch of an `if constexpr`. +/// The following example shows how to use `dependent_false` in that case. 
+/// +/// @code +/// template +/// void foo (T t) +/// { +/// if constexpr (std::is_integral_v) { +/// do_integer_stuff(t); +/// } +/// else if constexpr (std::is_floating_point_v) { +/// do_floating_point_stuff(t); +/// } +/// else { +/// static_assert(dependent_false, "T must be " +/// "an integral or floating-point type."); +/// } +/// } +/// @endcode +/// +/// This implements the C++ Standard Library proposal P1830R1. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +/// +/// That proposal is under review as of 2022/12/05. +/// The following link shows P1830's current review status. +/// +/// https://github.com/cplusplus/papers/issues/572 +/// +/// P2593R0 proposes an alternate solution to this problem, +/// that would change the C++ language itself. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +/// +/// For headers in this library, however, we only consider library solutions +/// as work-arounds for future C++ features. +template +inline constexpr bool dependent_false = dependent_bool_value; + +} // end namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cf9b803b27b3b148e1441260471ecab99e82bfd3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/helper_macros.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Helper macros for the CUTLASS library +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +#ifdef CUTLASS_NAMESPACE +#define concat_tok(a, b) a ## b +#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) +#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#define CUTLASS_DEVICE inline +#endif + +#if ! defined(_MSC_VER) +#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline)) +#else +#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]] +#endif + +#define CUTLASS_HOST __host__ +#define CUTLASS_GLOBAL __global__ static + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) +{ } + +#if defined(__GNUC__) + #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) +#else + #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) +#endif + +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } + #else + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif +#else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() 
assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) + #endif +#endif + +// CUTLASS_CMATH_NAMESPACE is the namespace where code can find +// functions like isnan and log. Such functions are in +// the std namespace in host code, but in the global namespace +// in device code. +// +// The intended use case for this macro is in "using" declarations +// for making argument-dependent lookup (ADL) work in generic code. +// For example, if T is cutlass::half_t, the following code will +// invoke cutlass::isnan(half_t). If T is float, it will invoke +// std::isnan on host and ::isnan on device. (CUTLASS's support +// for NVRTC prevents it from using things in the std namespace +// in device code.) Correct use of "using" declarations can help +// avoid unexpected implicit conversions, like from half_t to float. +// +// template +// bool foo(T x) { +// using CUTLASS_CMATH_NAMESPACE :: isnan; +// return isnan(x); +// } +// +// Without this macro, one would need to write the following. 
+// +// template +// bool foo(T x) { +// #if defined(__CUDA_ARCH__) +// using ::isnan; +// #else +// using std::isnan; +// #endif +// return isnan(x); +// } + +#if defined(__CUDA_ARCH__) +# define CUTLASS_CMATH_NAMESPACE +#else +# define CUTLASS_CMATH_NAMESPACE std +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 +#endif + + +// CUDA 10.1 introduces the mma instruction +#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_ASSERT(x) assert(x) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. 
+#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) + #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) + #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") + #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") + #else + #define CUTLASS_PRAGMA_UNROLL #pragma unroll + #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + #endif + + #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + +#else + + #define CUTLASS_PRAGMA_UNROLL + #define CUTLASS_PRAGMA_NO_UNROLL + #define CUTLASS_GEMM_LOOP + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +#define CUTLASS_THREAD_LOCAL thread_local +#else +#define CUTLASS_THREAD_LOCAL +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(_MSVC_LANG) +# define CUTLASS_CPLUSPLUS _MSVC_LANG +#else +# define CUTLASS_CPLUSPLUS __cplusplus +#endif + +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/n4762.pdf +// Section 14.8 Predefined macro names +#if (201703L <= CUTLASS_CPLUSPLUS) +#define CUTLASS_CONSTEXPR_IF_CXX17 constexpr +#define CUTLASS_CXX17_OR_LATER 1 +#else +#define CUTLASS_CONSTEXPR_IF_CXX17 +#define CUTLASS_CXX17_OR_LATER 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// __CUDA_ARCH_SPECIFIC__ is introduced in CUDA 12.9 +#if !defined(CUDA_ARCH_CONDITIONAL) + +#if defined(__CUDA_ARCH_SPECIFIC__) +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (__CUDA_ARCH_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_CONDITIONAL(ARCH_XXYY) (false) +#endif + +#endif + +// __CUDA_ARCH_FAMILY_SPECIFIC__ is introduced in CUDA 12.9 +#if !defined(CUDA_ARCH_FAMILY) + +#if defined(__CUDA_ARCH_FAMILY_SPECIFIC__) +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (__CUDA_ARCH_FAMILY_SPECIFIC__ == ARCH_XXYY) +#else +#define CUDA_ARCH_FAMILY(ARCH_XXYY) (false) +#endif + +#endif + +#if 
!defined(CUDA_ARCH_CONDITIONAL_OR_FAMILY) +#define CUDA_ARCH_CONDITIONAL_OR_FAMILY(ARCH_XXYY) \ + (CUDA_ARCH_CONDITIONAL(ARCH_XXYY) || CUDA_ARCH_FAMILY(ARCH_XXYY)) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +}; // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e1c1bd6c5529ccb7af5c70e558be327c81396106 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/layout.hpp @@ -0,0 +1,434 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/swizzle.hpp" // cute::Swizzle +#include "cute/swizzle_layout.hpp" // cute::get_swizzle_portion +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_tma.hpp" +#include "cute/arch/copy_sm100_tma.hpp" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/detail/collective.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each cutlass::layout, provides its corresponding cute stride types, 64b by default + +template +struct TagToStrideA { + using type = L; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = 
cute::Stride, int64_t, int64_t>; + using tag = layout::ColumnMajor; +}; + +template +struct TagToStrideB { + using type = L; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t>; + using tag = layout::ColumnMajor; +}; + +// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default +// Used by pointer array and grouped gemm +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [M, N, L] +template +struct TagToStrideC : TagToStrideA { }; + +// Conv: Maps to modes ((P,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,Z,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = 
cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R,T), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R,T), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + +// Convenience aliases +template +using TagToStrideA_t = typename TagToStrideA::type; + +template +using TagToStrideB_t = typename TagToStrideB::type; + +template +using TagToStrideC_t = typename TagToStrideC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For 2.x compatibility APIs, provide stride->layout tag mappers + +template +constexpr bool +is_major(Stride = {}) { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; +} + 
+template +constexpr bool +is_major(cute::Layout = {}) { + return is_major(Stride{}); +} + +// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices +template +constexpr +auto +stride_to_layout_tag_A() { + using InternalStrideA = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_A(); + } + else if constexpr (is_major<0, StrideA>()) { // M major + return layout::ColumnMajor{}; + } + // Specialize for sparse layout + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + cute::is_same_v(InternalStrideA{}))>>) { + return layout::ColumnMajor{}; + } + else { // K major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_B() { + using InternalStrideB = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_B(); + } + else if constexpr (is_major<0, StrideB>()) { // N major + return layout::RowMajor{}; + } + else { // K major + return layout::ColumnMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_C() { + using InternalStrideC = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_C(); + } + else if constexpr (is_major<0, StrideC>()) { // M major + return layout::ColumnMajor{}; + } + else { // N major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// Utilities to map Stride back on to their corresponding layout tags +template +struct StrideToLayoutTagA { + using type = decltype(detail::stride_to_layout_tag_A()); +}; + +template +struct StrideToLayoutTagB { + using type = decltype(detail::stride_to_layout_tag_B()); +}; + +template +struct StrideToLayoutTagC { + using type = decltype(detail::stride_to_layout_tag_C()); +}; + +// Convenience aliases +template +using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; + +template 
+using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; + +template +using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Inspects a tiled copy and whether its copy engine is TMA or not +template +constexpr bool is_tma_copy_engine() { + if constexpr (cute::is_void_v) { + return false; + } + else { + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + ) { + return true; + } + } + return false; +} + +template +struct RawDtype { using type = X; }; + +template +struct RawDtype> { using type = typename X::raw_type; }; + + +// Inspects a TiledCopy and returns its alignment in terms of element count +template +constexpr int +get_alignment_count_from_gmem_tiled_copy() { + + if constexpr (cute::is_void_v) { + return 1; + } + + // Account for ElementC = void kernels + else if constexpr (cute::is_void_v) { + return 0; + } + + else { + // For TMA tiled copies, we know the alignment has to be 128 bits + if constexpr (is_tma_copy_engine()) { + if constexpr ( cute::is_same_v::type, cutlass::detail::float_e2m1_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::float_e3m2_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::float_e2m3_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float4_unpacksmem_t> || + cute::is_same_v::type, cutlass::detail::type_erased_dynamic_float6_unpacksmem_t> || + cutlass::gemm::collective::detail::is_sm10x_f8f6f4_element() && cute::is_same_v::type, uint8_t>) { + return 128; + } + + // For sparse MMA, alignment in logical elements is increased by sparsity factor + if constexpr (cute::is_sparse_v) { + return 128 / sizeof_bits::value * ElementMma::sparsity; + } + return 128 / sizeof_bits::value; + } + else { + // For non-TMA tiled 
copies, TiledCopy holds the alignment count directly in its TiledShape_MN + return GmemTiledCopy::NumValSrc; + } + } +} + +// Return alignment bit requirements for the GEMM inputs. +template < + class ElementType + , bool IsF8F6F4SubBytes=false +> +constexpr int +get_input_alignment_bits() { + if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 4) { + // 16U4 format: The inner tensor size dimension should be multiple of 64B. + return 64 * 8; + } + else if constexpr (IsF8F6F4SubBytes && sizeof_bits::value == 6) { + // 16U6 format : The inner tensor size dimension must be a multiple of 96B. + return 96 * 8; + } + // TMA 16B alignment requirement + return 128; +} + +// Return alignment bit requirements for the GEMM outputs. +template +constexpr int +get_output_alignment_bits() { + if constexpr (sizeof_bits::value == 6) { + // 16U6 format : The inner tensor size dimension must be a multiple of 96B. + return 96 * 8; + } + // TMA 16B alignment requirement + return 128; +} + +// Check if tensor layout satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(cute::Layout const& layout) { + // Condition: shape must divide by Alignment without rounding + bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast(layout)); + // Condition: every dynamic stride must be a multiple of Alignment + bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static::value || (s % Alignment == 0); }); + return shape_check && stride_check; +} + +// Check if tensor layout satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(Shape const& shape, Stride const& stride) { + return check_alignment(cute::make_layout(shape, stride)); +} + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(cute::Swizzle) { + static_assert(B >= 0 and M >= 0); + return size_t(1) << size_t(B + M + cute::abs(S)); +} + +template +CUTLASS_HOST_DEVICE 
constexpr +size_t +alignment_for_swizzle(Layout layout) { + return alignment_for_swizzle(cute::get_swizzle_portion(layout)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..84de1c7d3c9359b94ababf184a2c2db724236b11 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Mainloop Fusion configs specific for scale factors +*/ + +#pragma once + +#include // cute::void_t + +namespace cutlass::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ElementSFType { + using type = void; +}; + +template +struct ElementSFType> { + using type = typename CollectiveMainloop::ElementSF; +}; + +template +struct LayoutSFAType { + using type = void; +}; + +template +struct LayoutSFAType> { + using type = typename CollectiveMainloop::LayoutSFA; +}; + +template +struct LayoutSFBType { + using type = void; +}; + +template +struct LayoutSFBType> { + using type = typename CollectiveMainloop::LayoutSFB; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp new file mode 100644 
index 0000000000000000000000000000000000000000..b4cbd3864a7fbfc524229cb183c62564cead1e7f --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/mma.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cute/layout.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsSparseTensorOp : cute::false_type { }; + +// TiledMma for sparse must have ValTypeE +template +struct IsSparseTensorOp> + : cute::true_type { }; + + +template +struct IsBlockScaledTensorOp : cute::false_type { }; + +// TiledMma for blockScaled must have FrgTypeSFA +template +struct IsBlockScaledTensorOp> + : cute::true_type { }; + + +// The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. +template +struct get_operator_class { + static constexpr bool is_sparse_op = IsSparseTensorOp::value; + static constexpr bool is_block_scaled_op = IsBlockScaledTensorOp::value; + // All tensorop operations have atom shape's M >= 8 + static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; + using type = cute::conditional_t< + is_tensor_op, + cute::conditional_t< + is_sparse_op, + cutlass::arch::OpClassSparseTensorOp, + cute::conditional_t< + is_block_scaled_op, + cutlass::arch::OpClassBlockScaledTensorOp, + cutlass::arch::OpClassTensorOp + > + >, + cutlass::arch::OpClassSimt + >; +}; + +template +using get_operator_class_t = typename get_operator_class::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..e4f20cb237cb9b6960275339ee803e00f8e40031 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_blockscaled_layout.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Blocked Scale configs specific for SM100 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm1xxBlockScaledBasicChunk { + + using Blk_MN = _128; + using Blk_SF = _4; + + using SfKMajorAtom = Layout< Shape< Shape<_32,_4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; + using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_32,_4>>, + Stride, Stride<_16,_4>>>; + using SfAtom = cute::conditional_t; +}; + +template +struct Sm1xxBlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm1xxBlkScaledChunk = Sm1xxBlockScaledBasicChunk; + using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledChunk::SfAtom; + + using LayoutSF = decltype(blocked_product(SfAtom{}, make_layout( make_shape(int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0))))); + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSF{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
+ template < class ProblemShape, class LayoutSFA = LayoutSF> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFA(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + constexpr int MMA_NSF = TiledMma::K / SFVecSize; + // Basic storage block for new Scaling Factor Layouts + using mnBasicBlockShape = Shape<_32,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; + using kBasicBlockStride = Stride<_0, _1>; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 
32bits corresponds to the TMEM word size + using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + + using TL_VMNK = typename TiledMma::ThrLayoutVMNK; + constexpr TL_VMNK tl_vmnk{}; + constexpr int MMA_M = cute::size<0>(TileShape_MNK{}) / cute::size<0>(tl_vmnk); + using mma_SFA_shape = decltype( make_shape( prepend(Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); + using mma_SFA_stride = decltype(make_stride( prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); + using sSFA_shape = decltype( make_shape( mma_SFA_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); + using sSFA_stride = decltype(make_stride(mma_SFA_stride{}, _0{}, make_stride( Int{}, Int{}))); + using SmemLayoutAtomSFA = decltype(make_layout(sSFA_shape{}, sSFA_stride{})); + return SmemLayoutAtomSFA{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFB(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + constexpr int MMA_NSF = TiledMma::K / SFVecSize; + // Basic storage block for new Scaling Factor Layouts + using mnBasicBlockShape = Shape<_32,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; + using kBasicBlockStride = Stride<_0, _1>; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 
32bits corresponds to the TMEM word size + using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + + using TL_VMNK = typename TiledMma::ThrLayoutVMNK; + constexpr TL_VMNK tl_vmnk{}; + constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + constexpr int MMA_N_SFB = cutlass::ceil_div(MMA_N, Blk_MN{}) * Blk_MN{}; + using mma_SFB_shape = decltype(make_shape( prepend( Int{}/Blk_MN{}, mnBasicBlockShape{}), kBasicBlockShape{})); + using mma_SFB_stride = decltype(make_stride(prepend( Blk_Elems{}, mnBasicBlockStride{}), kBasicBlockStride{})); + using sSFB_shape = decltype( make_shape( mma_SFB_shape{}, _1{}, make_shape( Blk_SF{}/Int{}, Int(TileShape_MNK{}) / SFVecSize / Blk_SF{}>{}))); + using sSFB_stride = decltype(make_stride(mma_SFB_stride{}, _0{}, make_stride( Int{}, Int{}))); + using SmemLayoutAtomSFB = decltype(make_layout(sSFB_shape{}, sSFB_stride{})); + return SmemLayoutAtomSFB{}; + } +}; + + +template +struct Sm1xxBlockScaledOutputConfig { + // We are creating the SFD tensors' layouts in the collective. 
+ // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm1xxBlkScaledChunk = cutlass::detail::Sm1xxBlockScaledBasicChunk; + using Blk_MN = typename Sm1xxBlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledChunk::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledChunk::SfAtom; + + using LayoutKMajorSF = decltype(blocked_product(SfAtom{}, make_layout(make_shape (int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0))))); + + using LayoutMNMajorSF = decltype(blocked_product(SfAtom{}, make_layout(make_shape (int32_t(0), int32_t(0), int32_t(0)), + make_stride( _1{}, int32_t(0), int32_t(0))))); + + using LayoutSF = cute::conditional_t; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFD() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFC. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFD(ProblemShape problem_shape, LayoutSFD layout_sfc = LayoutSFD{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + if constexpr (major == UMMA::Major::K) { + return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); + } + else { + return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_1,_2,_3>{}); + } + } +}; + +//// Describe the Scalefactor Tensor without VectorSize +struct Sm1xxBlockScaledTensorConfig { + // k-major order + // The blockscaled tensor does not need to know vectorsize + using Blk_M = _128; + using Blk_N = _4; + using SfAtom = Layout< Shape< Shape<_32,_4>, Shape<_4>>, + Stride, Stride<_1>>>; + + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape(ProblemShape problem_shape) { + auto problem_shape_MNL = append<3>(problem_shape, 1); + auto [M, N, L] = problem_shape_MNL; + return tile_to_shape(SfAtom{}, make_shape(M,N,L), Step<_2,_1,_3>{}); + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6c92c4d1199995029a550f61dce6a9903d7333e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm100MixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, int32_t>, Shape, int32_t>, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layout_scale(CtaShape_MN_K cta_shape_mn_k) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + + int constexpr size_MN = cute::get<0>(CtaShape_MN_K{}); + int constexpr size_K = cute::get<1>(CtaShape_MN_K{}); + + int constexpr SmemSizeMN = (SFVecSizeMN < size_MN) + ? SFVecSizeMN + : size_MN; + + int constexpr SmemSizeK = (SFVecSizeK < size_K) + ? 
SFVecSizeK + : size_K; + + int constexpr div_MN = cute::ceil_div(size_MN, SmemSizeMN); + int constexpr div_K = cute::ceil_div(size_K, SmemSizeK); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int{})); + } + else { + return make_stride(make_stride(_0{}, Int{}), make_stride(_0{}, _1{})); + } + }(); + + return make_layout( + make_shape(make_shape(Int{}, Int{}), + make_shape(Int{}, Int{})), + strides + ); + } + + + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ScaledInputDim scale_input_dims) { + const auto scale_input_dims_MNKL = append<3>(scale_input_dims, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = scale_input_dims_MNKL; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, SFVecSizeMN))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = scale_input_dims_MNKL; + auto mk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(MN, SFVecSizeMN)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +template +struct RuntimeMixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, Shape, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_S. 
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<3>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + if constexpr (majorScale == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, sfmn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfmn, cute::ceil_div(MN, sfmn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f12bac12dc898f177d95438ca10a1b060f4402ac --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm100_tmem_helper.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! 
\file + \brief TMEM Accumulator Helpers for SM100 +*/ + +#pragma once + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + + +namespace cutlass::detail{ +constexpr uint32_t TmemColMask = 0x0000'FFFF; + +template +CUTE_HOST_DEVICE +static constexpr auto find_tmem_tensor_col_offset(TmemTensor tensor) { + using namespace cute; + return cosize(recast(tensor).layout()) & TmemColMask; +} + +template +CUTE_HOST_DEVICE +static constexpr auto make_sm100_accumulator(TiledMma tiled_mma, AccumulatorShape acc_shape, EpilogueTile epilogue_tile) { + using namespace cute; + static_assert(rank(acc_shape) == 3 || (rank(acc_shape) == 4 && IsOverlappingAccum == false), + "Expect a rank >= 3 accumulator shape compatible with an SM100 tiled mma, Overlapping accumulators is only available for non-complex kernels"); + if constexpr (IsOverlappingAccum) { + Tensor accumulators_tmp = TiledMma::make_fragment_C(append(acc_shape, Int<2>{})); + return make_tensor( + accumulators_tmp.data(), + shape(accumulators_tmp), + replace<3>( + stride(accumulators_tmp), + Int<(256 - size<1>(EpilogueTile{})) * stride<0, 1>(accumulators_tmp.layout())>{})); + } else { + return TiledMma::make_fragment_C(append( + acc_shape, + Int{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) + } +} +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..300448d7cd0273ac4572b5484d5780d98576d4b5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/detail/sm103_blockscaled_layout.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Blocked Scale configs specific for SM103 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm103BlockScaledBasicChunk { + + using Blk_MN = _128; + using Blk_SF = _4; + + using SfKMajorAtom = Layout< Shape< Shape< _8, _4, _4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; + using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_8, _4, _4>>, + Stride, Stride<_16,_128, _4>>>; + using SfAtom = cute::conditional_t; +}; + +template +struct Sm103BlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm103BlkScaledChunk = Sm103BlockScaledBasicChunk; + using Blk_MN = typename Sm103BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm103BlkScaledChunk::Blk_SF; + using SfAtom = typename Sm103BlkScaledChunk::SfAtom; + + using LayoutSF = decltype(tile_to_shape(SfAtom{}, make_shape(int(0),int(0),int(0)),Step<_2,_1,_3>{})); + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSF{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
+ template < class ProblemShape, class LayoutSFA = LayoutSF> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5b1d3e5b1feb5e38ec9a57e6ee784b3e0e9b5a27 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/device_kernel.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for generic CUTLASS kernel. +*/ + +#pragma once + +#include // CUTLASS_HOST_DEVICE +#include // cutlass::arch::synclog_* +#include // uint64_t + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) && !CUTLASS_CLANG_CUDA +# define CUTLASS_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +# define CUTLASS_GRID_CONSTANT_ENABLED +#endif + +#if ! 
defined(CUTLASS_GRID_CONSTANT) +# if defined(CUTLASS_GRID_CONSTANT_ENABLED) +# define CUTLASS_GRID_CONSTANT __grid_constant__ +# else +# define CUTLASS_GRID_CONSTANT +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +template struct Type2Type { using type=T; }; +// using the simple type to replace the complex type to reduce this symbol size +template struct GetUnderlyingKernel : public Type2Type {}; +template class Wrapper > struct GetUnderlyingKernel> : public Wrapper {}; +template using GetUnderlyingKernel_t = typename GetUnderlyingKernel::type; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Generic CUTLASS kernel template. +template +CUTLASS_GLOBAL +void Kernel(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + // Declare pointer to dynamic shared memory. + typename Operator::SharedStorage *shared_storage = + reinterpret_cast(SharedStorageBase); + + Operator op; + + op(params, *shared_storage); + cutlass::arch::synclog_print(); +} + + +/// Generic CUTLASS kernel template. +template +CUTLASS_GLOBAL +void Kernel2(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + // Declare pointer to dynamic shared memory. + typename Operator::SharedStorage *shared_storage = + reinterpret_cast(SharedStorageBase); + + Operator::invoke(params, *shared_storage); + cutlass::arch::synclog_print(); + +} + + +//////////////////////////////////////////////////////////////////////////////// +// +// 3.0 specific launch +// +//////////////////////////////////////////////////////////////////////////////// + +/// Generic CUTLASS kernel template. +template +CUTLASS_GLOBAL +#ifdef __CUDACC__ +// Enclosing this in __CUDACC__ suppresses MSVC warnings. 
+__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +#endif // __CUDACC__ +void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) +{ + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + Operator op; + op(params, smem); + cutlass::arch::synclog_print(); + +} + +//////////////////////////////////////////////////////////////////////////////// +} /// namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2bd817a5dd6e8cad0b4295ae1ff41d1f838eebf3 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_builder.hpp @@ -0,0 +1,126 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::DefaultCopy +#include // cute::is_base_of_v + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. 
+// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class AccLoadOp = cute::DefaultCopy, + class = void +> +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template < + class DispatchPolicy, + class FusionCallbacks, + class TileShape_MNK, + class EpilogueTile_MN, + class AccLoadOp, + class ElementAccumulator +> +struct CallbacksBuilder< + DispatchPolicy, + FusionCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp, + cute::enable_if_t> +> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_builder.inl" +#include "builders/sm100_builder.inl" +#include "builders/sm103_builder.inl" +#include "builders/sm120_builder.inl" + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..918017efa4c22da5ad673fbecb55d2c7cea4d68c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "detail.hpp" + +// +// Gemm +// +#include "default_epilogue.hpp" +#include "default_epilogue_array.hpp" +#include "epilogue_tensor_broadcast.hpp" +#include "sm70_epilogue_vectorized.hpp" +#include "sm70_epilogue_vectorized_array.hpp" +#include "sm90_epilogue_tma_warpspecialized.hpp" +#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#include "sm90_epilogue_array_tma_warpspecialized.hpp" +#include "sm100_epilogue_nosmem.hpp" +#include "sm100_epilogue_array_nosmem.hpp" +#include "sm100_epilogue_tma_warpspecialized.hpp" +#include "sm100_epilogue_array_tma_warpspecialized.hpp" +// +// Conv +// 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed34bc10719d2ad45d22d890e3275ce8046c5385 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. 
+template < + class ElementC_, + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemElementC = cute::conditional_t, ElementD, ElementC>; // prevents void ref breakages + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape 
const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + // Note: SharedStorage is unused for DefaultEpilogue + CUTLASS_HOST_DEVICE + DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + [[maybe_unused]] ResidueMNK, + int thread_idx, + [[maybe_unused]] char*) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) 
+ Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + auto shape_MN = make_shape(M,N); + Tensor mD_crd = make_identity_tensor(shape_MN); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) + Tensor tCcD_mn = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (BLK_M,BLK_N) + Tensor tCcD = make_coord_tensor(tCcD_mn.layout()); // (VEC,THR_M,THR_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = shape_MN - cD_mn(_0{}); // (m,n) + auto residue_tCcD = shape_MN - tCcD_mn(_0{}); // (m,n) + + // Fully OOB tile + if (not elem_less(repeat_like(residue_cD, _0{}), residue_cD)) { + return; + } + + using FragCType = remove_cvref_t; + using FragDType 
= remove_cvref_t; + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + FragCType fragC; + bool pred = elem_less(tCcD(i), residue_tCcD); + arch::global_load(fragC, &tCgC(i), pred); + FragDType fragD = epilogue_op(accumulators(i), fragC); + arch::global_store(fragD, &tCgD(i), pred); + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + bool pred = elem_less(tCcD(i), residue_tCcD); + FragDType fragD = epilogue_op(accumulators(i)); + arch::global_store(fragD, &tCgD(i), pred); + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cab46ddcfd86ecbb2d3f1de43856f91e1002bfd --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/trace.h" + +#include "cutlass/cuda_host_adapter.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Applies an element wise operation to all elements within the fragment +// and writes them out to destination storage. +template < + class ElementC_, + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogueArray { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using GmemElementC = cute::conditional_t, ElementD, ElementC>; // prevents void ref breakages + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); + 
static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + DefaultEpilogueArray(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + [[maybe_unused]] ResidueMNK, + int thread_idx, + [[maybe_unused]] char*) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. 
+ ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator 
count must have the same destination element count."); + + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(blk_shape_MNK), make_coord(m_coord, n_coord)); // (BLK_M,BLK_N) + Tensor tCcD = thr_mma.partition_C(cD_mn); // (VEC,THR_M,THR_N) + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_shape(M,N))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_shape(M,N))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fb09f8b19475fdeeca844b20a158933726d2a895 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/detail.hpp @@ -0,0 +1,887 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cute/util/type_traits.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool +is_m_major() { + return cutlass::gemm::detail::is_major<0,Stride>(); +} + +template +constexpr bool +is_n_major() { + return cutlass::gemm::detail::is_major<1,Stride>(); +} + +template +constexpr bool +is_im2col() { + return cute::is_same_v> + || cute::is_same_v> + || cute::is_same_v>; +} + +template +struct sm90_is_ptr_array_tma : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma::value; + +template +struct sm90_is_ptr_array_tma_cooperative : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_cooperative : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative::value; + +template +struct sm90_is_ptr_array_tma_pingpong : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_pingpong : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong::value; + +template +struct sm90_is_ptr_array_tma_dispatch_policy : 
cute::false_type {}; + +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm90PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm120PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; + +using cutlass::atomic_maximum; + +template +static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + +template +static constexpr bool sm90_is_cooperative_v = + cute::is_base_of_v || + sm90_is_ptr_array_tma_cooperative_v; + +template +static constexpr bool sm90_is_warp_specialized_v = + (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || + cute::is_base_of_v; + +template +static constexpr bool is_im2col_mode = + cute::is_same_v || + cute::is_same_v || + cute::is_same_v; + +template +struct EmptyStorage { + CUTLASS_HOST_DEVICE + T* data() { return nullptr; } +}; + +template +CUTLASS_HOST_DEVICE +auto get_epilogue_stride(Stride stride){ + if constexpr (cute::is_base_of_v|| + cute::is_base_of_v) { + return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); + } + else { + return stride; + } +} + +template +struct IsThreadEpilogueOpWithBias { + static constexpr bool value = false; + using type = typename ThreadEpilogueOp::ElementCompute; +}; + +template +struct IsThreadEpilogueOpWithBias > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ElementBias; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling > { + 
static constexpr bool value = true; +}; + +template +struct IsThreadEpilogueOpWithResidualAdd { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithResidualAdd > { + static constexpr bool value = ThreadEpilogueOp::IsResidualSupported; +}; + +template +struct IsThreadEpilogueOpWithActivation { + static constexpr bool value = false; + using type = void; +}; + +template +struct IsThreadEpilogueOpWithActivation > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ActivationFn; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaled { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaled > { + static constexpr bool value = ThreadEpilogueOp::IsPerRowScaleSupported || ThreadEpilogueOp::IsPerColScaleSupported; +}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments< + ThreadEpilogueOp, + cute::void_t> : cute::true_type {}; + +// Check if ActivationFn has 'Arguments' type defined +template +struct sm100_act_has_arguments : cute::false_type {}; + +template +struct sm100_act_has_arguments > : cute::true_type {}; + +template +struct Sm100EpilogueOpNumAccumulatorMtxs { + static constexpr int value = 1; +}; + +template +struct Sm100EpilogueOpNumAccumulatorMtxs> { + static constexpr int value = EpilogueOp::NumAccumulatorMtxs; +}; + + +// Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels +template +class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; + using LoadPipelineState = cutlass::PipelineState<0>; + constexpr static uint32_t TmaTransactionBytes = 0; + constexpr static bool RequiresTransactionBytes = false; + + using StorePipeline = cutlass::PipelineTmaStore<0>; + using StorePipelineState = 
cutlass::PipelineState<0>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using TensorMapStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK) { + return 1; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapter( + typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorStorage& shared_tensors) + : EpilogueOp(params) { } + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE auto + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx) { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] CtaTileMNK cta_tile_mnk, + [[maybe_unused]] CtaCoordMNKL cta_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] int subtile_idx=-1) + { + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] 
ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] TileShapeMNK tile_shape_MNK, + [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] TensorMapC const& load_tensormap, + [[maybe_unused]] int subtile_idx=-1, + [[maybe_unused]] bool wait = false) + { + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) + { + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx) { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(cta_tile_mnk); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(cta_tile_mnk); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + })); + + auto residue_mnk = 
make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + [[maybe_unused]] TensorMapD const& store_tensormap, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + tile_shape_MNK, + tile_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + return 
cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + // Dummy methods to perform different parts of TMA/Tensormap modifications + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] ProblemShapeMNKL problem_shape, + [[maybe_unused]] int32_t next_batch, + [[maybe_unused]] int32_t warp_group_idx) { } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] int32_t warp_group_idx) { } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } +}; + + +// Wrapper class to use operator-style epilogues in sm100 TMA warp-specialized kernels +template +class Sm100TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; // 0 stage to disable smem alloc + using LoadPipelineState = cutlass::PipelineState<0>; + + using StorePipeline = cutlass::PipelineTmaStore<1>; // tma store pipe has no smem alloc + using StorePipelineState = cutlass::PipelineState<1>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using TensorMapStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK) { + return 1; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { + } + + CUTLASS_DEVICE + 
bool + is_producer_load_needed() const { + return false; + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + CUTLASS_DEVICE auto + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] int32_t const sm_count, + [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + bool reverse_epi_n = false) + { + // C load is performed in epilogue operator + return load_pipe_producer_state; + } + + // with Tensormap + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileShapeMNK, + class CtaTileCoordMNKL, + class MmaTileMNK, + class TiledMma, + class TensorMap + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileShapeMNK tile_shape_mnk, + CtaTileCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + [[maybe_unused]] cute::tuple const& load_tensormap_info, + bool reverse_epi_n = false) + { + // C load is performed in epilogue operator + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) + { + } + + CUTLASS_DEVICE auto + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + 
[[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] int32_t const sm_count, + [[maybe_unused]] int32_t const sm_idx) const { + return cute::make_tuple(nullptr); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors + ) + { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + + auto [acc_state_next] = (*this).template operator()( + acc_pipeline, + acc_pipe_consumer_state, + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + shared_tensors); + + // Let mma warp know tmem buffer is consumed and empty + ++load_pipe_consumer_state; + ++store_pipe_producer_state; + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + } + + // FastF32 API + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + 
cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r) + { + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tTR_rAcc, + shared_tensors, + tiled_t2r); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + // FastF32 API with Tensor Map + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R, + class TensorMap + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, + TensorStorage& shared_tensors, + TensorMap tensormap, + TiledCopyT2R tiled_t2r) { + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tTR_rAcc, + shared_tensors, + tiled_t2r); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class TileCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TensorMap + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + TileCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors, + TensorMap tensormap + ) + { + // Wait for mma warp to fill tmem buffer with accumulator 
results + acc_pipeline.consumer_wait(acc_pipe_consumer_state); + + auto [acc_state_next] = (*this).template operator()( + acc_pipeline, + acc_pipe_consumer_state, + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + shared_tensors); + + // Let mma warp know tmem buffer is consumed and empty + ++load_pipe_consumer_state; + ++store_pipe_producer_state; + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_state_next); + } + + template + CUTLASS_DEVICE void + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + [[maybe_unused]] LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state, + [[maybe_unused]] CtaTileMNK cta_tile_mnk) + { + } + + // Dummy methods to perform different parts of TMA/Tensormap modifications + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] ProblemShape problem_shape, + [[maybe_unused]] int32_t next_batch) { } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + [[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap) { } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } +}; + + +// SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. 
+template +struct has_beta { + static constexpr bool value = false; +}; + +template +struct has_beta> { + static constexpr bool value = true; +}; + +template +struct has_beta_ptr { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr> { + static constexpr bool value = true; +}; + +template +struct has_beta_ptr_array { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr_array> { + static constexpr bool value = true; +}; + +} // namespace detail +} // namespace collective +} // namespace epilogue +} // namespace cutlass diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d32dd6aeefe91b2663a9c9adeee3848e16f6c08f --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -0,0 +1,271 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor for performing tensor-tensor broadacasts atop existing epilogues. 
+ + Concretely, the opeartion performed is the following: + UnaryOp( + BinaryOp1( + BinaryOp0( + Activation((alpha * A @ B) + bias), + beta * C0 + ), + beta * C1 + ) + ) + + where: + - C0 and C1 have the same extents as the output + - BinaryOp0 and BinaryOp1 perform elementwise binary operations + - UnaryOp is an elementwise operation +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collective epilogue that applies elementwise tensor-tensor operations atop other epilogues +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + bool PerColumnBias_ = false +> +class EpilogueTensorBroadcast { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename ThreadEpilogueOp::ElementBias; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static constexpr int kOutputAlignment = 
ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static constexpr bool IsBinaryOp0Enabled = ThreadEpilogueOp::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled; + + static constexpr bool PerColumnBias = PerColumnBias_; + using BiasStride = typename cute::conditional_t, Stride<_1, _0, _0>>; + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias* ptr_Bias = nullptr; + ElementC* ptr_C0 = nullptr; + ElementC* ptr_C1 = nullptr; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueTensorBroadcast(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source0_needed() || epilogue_op.is_source1_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class 
ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + auto stride_bias = detail::get_epilogue_stride(BiasStride{}); + + // Represent the full output tensor + Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC1_mnl = make_tensor(make_gmem_ptr(params.ptr_C1), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), stride_bias); // (m,n,l) + + Tensor gC0_mnl = local_tile(mC0_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC1_mnl = local_tile(mC1_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this thread block 
is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC0 = gC0_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC1 = gC1_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC0 = thr_mma.partition_C(gC0); // (VEC,THR_M,THR_N) + Tensor tCgC1 = thr_mma.partition_C(gC1); // (VEC,THR_M,THR_N) + Tensor tCgBias = thr_mma.partition_C(gBias); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, + "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC0) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC1) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + CUTE_STATIC_ASSERT_V(size(tCgBias) == size(accumulators), + "Accumulator count must have the same destination element count."); + + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + bool bias_needed = params.ptr_Bias != nullptr; + bool c0_needed = (params.ptr_C0 != nullptr) && epilogue_op.is_source0_needed(); + bool c1_needed = (params.ptr_C1 != nullptr) && epilogue_op.is_source1_needed(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + ElementBias bias = bias_needed ? tCgBias(i) : ElementBias(0); + ElementC c0 = c0_needed ? tCgC0(i) : ElementC(0); + ElementC c1 = c1_needed ? 
tCgC1(i) : ElementC(0); + + tCgD(i) = epilogue_op(accumulators(i), c0, c1, bias); + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d3b2d0880e56fe65d7dda6efb982dee52f23b3e2 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -0,0 +1,937 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by Ptr-Array and Grouped GEMM epilogues. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. 
+template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PtrArrayNoSmem; + using EpilogueTile = EpilogueTile_; + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + using AlignmentC = AlignmentC_; + using AlignmentD = AlignmentD_; + + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + constexpr static int ThreadCount = 128; + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + struct TensorStorage { }; + struct TensorMapStorage { }; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using 
TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&) { + + using namespace cute; + using X = Underscore; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(TileCoordMNKL{}) == 4, 
"TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = 
local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; + constexpr int VD = cute::min(AlignmentD{}, size(mclD)); + Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); + Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); + + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); + Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tDpD); t++) { + tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); + } + + constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; + constexpr int VC = cute::min(AlignmentC{}, size(mclC)); + + Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); + Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); + Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tCpC); t++) { + tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); + } + Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); + Tensor tTR_rC_dst = 
recast>(coalesce(tTR_rC)); + + // Detect interleaved complex fp32 kernels + [[maybe_unused]] Tensor accs = accumulators; + using ElementTmem = typename decltype(accs)::value_type; + constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; + + // 1. Load accumulators into register from tmem + // Tmem -> rmem and transformation for interleaved complex kernels + if constexpr (is_interleaved_complex_f32) { + using ElementComputeAccumulator = float; + + Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) + Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); + copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); + + // 1.1. Transform accumulators in registers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAccReal); i++) { + tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; + } + } + + // Standard tmem -> rmem epilogue + else { + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + } + + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + + // 2. 
Apply element-wise operation and store to gmem + // source is needed + if (epilogue_op.is_source_needed()) { + copy_if(tCpC, tTR_rC_src, tTR_rC_dst); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + + return cute::make_tuple(acc_pipe_consumer_state); + } + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledCopy + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&, + TiledCopy tiled_t2r) { + + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Slice to get the tile this CTA is 
responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + ThreadEpilogueOp epilogue_op{params.thread}; + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // 2. 
Apply element-wise operation and store to gmem + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); + } + } + } + } + +protected: + Params const& params; +}; + +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + // Required by the gemm::kernel + using DispatchPolicy = Sm100PtrArrayNoSmem; + using ElementC = ElementC_; + using ElementD = ElementD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using StrideC = StrideC_; + using StrideD = StrideD_; + using InternalStrideC = cute::remove_pointer_t; + using InternalStrideD = cute::remove_pointer_t; + using EpilogueTile = EpilogueTile_; + using CopyOpT2R = CopyOpT2R_; + using FusionCallbacks = FusionCallbacks_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + +private: + constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported + || is_same_v; // alloc reduction buffer for custom EVTs + constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? 
size(EpilogueTile{}) : 0; + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + +public: + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + array_aligned buffer; + }; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + : fusion_callbacks(params_.thread, shared_tensors.thread) + , smem_buffer_ptr(shared_tensors.buffer.data()) + , params(params_) {}; + +protected: + FusionCallbacks fusion_callbacks; + uint8_t* smem_buffer_ptr; + Params const& params; + +public: + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t 
stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return fusion_implementable; + } + + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + [[maybe_unused]] SharedStorage& + ) { + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + // Wait for mma warp to fill tmem buffer with accumulator results + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C 
and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_mnk); + auto cta_coord_mnk = cute::make_coord(m_coord, n_coord, k_coord, cute::Int<0>{}); + + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (is_C_load_needed) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (is_C_load_needed) { + ptr_C_l = params.ptr_C[l_coord]; + } + + int thread_idx = threadIdx.x % ThreadCount; + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); + + // Construct the EVT consumer callbacks + auto residue_cD = make_coord(M,N) - cD(_0{}); + auto residue_tTR_cD = 
make_coord(M,N) - tTR_cD(_0{}); + Tensor cD_ = make_coord_tensor(cD.layout()); + Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); + constexpr bool RefSrc = false; + + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); + + Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mC, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D[l_coord])), problem_shape_mnl, stride_d); + + Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + // Register Tensor + Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); + + Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); + Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + coord_cCD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; + constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); + + auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); + constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; + constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); + + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); + + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + int(0), + EpilogueTile{}, + tiled_t2r, + cD_, + residue_cD, + tTR_cD_, + residue_tTR_cD, + tCrC, + thread_idx + }; + + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. 
+ synchronize(); + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + + // Lambda to process a single epilogue tile + auto process_tile = [&](int epi_m, int epi_n, int iter_m, int iter_n) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_last_iteration = iter_m == NumEpiSubtilesM-1 && iter_n == NumEpiSubtilesN-1; + bool do_acc_release = is_last_iteration; + + // Adjust release condition for tmem reuse + if constexpr (ReuseTmem) { + do_acc_release = iter_m == NumEpiSubtilesM-1 && iter_n == 0; // Release on first N iteration + } + + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); + cst_callbacks.begin_loop(epi_m, epi_n); + + if constexpr (not cute::is_void_v) { + if (is_C_load_needed) { + using CVecType = uint_bit_t>; + + if constexpr (!is_same_v) { + Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + 
Tensor tTR_rC_frg = recast(coalesce(tCrC)); + Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); + copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } + } + } + + // Copy accumulator tile from tmem to register + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); + + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + Tensor reduction_buffer = make_tensor( + raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); + + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); + + cst_callbacks.end_loop(epi_m, epi_n); + + using VecType = uint_bit_t>; + if constexpr (!is_same_v) { + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); + copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = 
thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + }; + + // Use static iteration with appropriate ordering + // When ReuseTmem is true and reverse_epi_n is true, we need reverse N iteration + auto n_seq = cute::make_int_sequence{}; + auto m_seq = cute::make_int_sequence{}; + + if constexpr (UnrollEpiLoop) { + // Fully unrolled static iteration + cute::for_each(n_seq, [&](auto I_N) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_n = I_N; + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + cute::for_each(m_seq, [&](auto I_M) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_m = I_M; + process_tile(iter_m, epi_n, iter_m, iter_n); + }); + }); + } else { + // Runtime loop with pragma unroll(1) + #pragma unroll 1 + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + #pragma unroll 1 + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + process_tile(iter_m, epi_n, iter_m, iter_n); + } + } + } + + cst_callbacks.end(); + }; + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); + return cute::make_tuple(acc_pipe_consumer_state); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For sm100 kernels requiring warp specialized epilogues +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC, + class AlignmentD +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmemWarpSpecialized, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC, + 
AlignmentD +> : public detail::Sm100TmaWarpSpecializedAdapter> +{ +public: + // ctor inheritance + using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f0a915d7d61411de8f1fd6158365904258bf9fd --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -0,0 +1,1526 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing elementwise operations used by Ptr-Array and Grouped Gemm epilogue. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm100PtrArrayTmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100PtrArrayTmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpT2R = CopyOpT2R_; + using 
CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + +private: + + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using GmemElementD = cute::conditional_t>; + using GmemElementC = cute::conditional_t; // prevents void ref breakages + static_assert(not cute::is_void_v, "GmemElementD is void"); + + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); + + constexpr static bool ReuseSmemC = ReuseSmemC_ && is_destination_supported; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + + constexpr static int StageCBits = cosize_v * sizeof_bits_v; + constexpr static int StageDBits = cosize_v * sizeof_bits_v; + constexpr static 
int MaxStageBits = cute::max(StageCBits, StageDBits); + constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; + constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; + + using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); + using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && MaxStageBits % sizeof_bits_v == 0 + && MaxStageBits % sizeof_bits_v == 0; + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. 
+ constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; + constexpr static uint32_t MinTensorMapWorkspaceAlignment = 64; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_C; + cute::TmaDescriptor smem_tensormap_D; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. 
+ constexpr static int NumAccumulatorMtxs = 1; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + struct Params { + using TensorShapeC = decltype(repeat_like(append<3>(StrideC{}, _1{}), int32_t(0))); + using TensorShapeD = decltype(repeat_like(append<3>(StrideD{}, _1{}), int32_t(0))); + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + TensorShapeC{}, + append<3>(InternalStrideC{}, _0{})), + SmemLayoutStageC{}, + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + TensorShapeD{}, + append<3>(InternalStrideD{}, _0{})), + SmemLayoutStageD{}, + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Gemm Host Functions + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(CtaTileShape{})); + auto init_N = int32_t(size<1>(CtaTileShape{})); + auto init_L = 1; + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. 
+ stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + + stride_c = args.dC; + stride_d = args.dD; + } + + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + // Tensor pointers will be fixed before the first access + ElementC const* ptr_C_first_batch = nullptr; + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, EpilogueTile{}, _1{}); + } + + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + // Tensor pointers will be fixed before the first access + ElementD* ptr_D_first_batch = nullptr; + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, EpilogueTile{}, _1{}); + } + + auto fusion_workspace = static_cast(workspace); + auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); + auto tma_descriptor_workspace = reinterpret_cast( + static_cast(workspace) + fusion_workspace_size); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), + tma_load_c, + tma_store_d, + tma_descriptor_workspace, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = cute::is_void_v ? 
1 : 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count) + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + bool implementable = true; + bool fusion_implementable = true; + + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host 
problem shape is not available.\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if (cute::is_void_v || args.ptr_C == nullptr) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + // + // Static Device Functions + // + + template + CUTLASS_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); + } + + template + CUTLASS_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + + // + // Non-static Device Functions + // +public: + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + CUTLASS_DEVICE auto + load_init( + Params const& params, + 
TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = true; + auto load_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + return cute::make_tuple(load_tensormap); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + cute::tuple load_tensormap_info, + bool reverse_epi_n = false) { + using namespace cute; + + // Check to see if tensormaps have been replaced in gmem + if (get<1>(load_tensormap_info) /* did_batch_change */) { + tensormaps_fence_acquire(get<0>(load_tensormap_info)); + } + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); + Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy 
thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + lane_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(get<0>(load_tensormap_info), *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + 
pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_init( + Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + // Fetch a copy of tensormaps for the CTA from Params + constexpr bool IsEpiLoad = false; + cute::TmaDescriptor* store_tensormap = nullptr; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + // Only the first epilogue warp needs to perform TMA related operations + if (warp_idx == 0) { + store_tensormap = tensormaps_init(params, shared_tensormap, sm_count, sm_idx); + } + return cute::make_tuple(store_tensormap); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors, + cute::tuple store_tensormap_info + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = 
cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // Check to see if tensormaps have been replaced in gmem + // Only the first epilogue warp needs to perform TMA related operations + if (get<1>(store_tensormap_info) /* did_batch_change */ && warp_idx == 0) { + tensormaps_fence_acquire(get<0>(store_tensormap_info)); + } + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = 
cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + // Does directly store the visitor into smem. + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, 
tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_rD = [&]() { + if constexpr (!IsDirectR2S) { + return make_tensor(shape(tRS_sD(_,_,_,_0{}))); + } + else{ + return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + } + }(); + + Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); + Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // 
(T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. 
+ [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (a single thread from one warp issues TMA store) + bool issue_tma_store = (warp_idx == 0) && cute::elect_one_sync(); + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) 
CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d.with(get<0>(store_tensormap_info)), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; // tma_store_fn + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int 
NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + // Compute tmem load predication if necessary + if constexpr 
(predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + if constexpr (!IsDirectR2S) { + // At present, only FP4 col output with scalefactor generation fusion would go into these branch + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy output tile from register to smem + bool issue_smem_store = issue_tmem_load; + if constexpr (is_destination_supported) { + if (issue_smem_store) { + 
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + }; + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. 
+ template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + TensorStorage& shared_tensors, + TensorMapD store_tensormap, + TiledCopyT2R tiled_t2r + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply 
epilogue subtiling + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (s)mem copy 
(tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported 
&& fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + 
++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + bool issue_smem_store = true; + Tensor tTR_rAcc_epi_tile = 
tTR_rAcc(_,_,_,epi_m,epi_n); + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); + + // Copy output tile from register to smem + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, 
store_pipe_producer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init(Params const& params, + TensorMapStorage& shared_tensormap, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* tma_desc = nullptr; + cute::TmaDescriptor* gmem_tensormap = params.tensormaps; + if constexpr (IsLoad) { + if (is_source_supported) { + tma_desc = &gmem_tensormap[sm_idx]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_C), Int<1>{}, Int<1>{}); + copy(recast(pC_tensormap), recast(sC_tensormap)); + } + __syncwarp(); + } + } else if constexpr (is_destination_supported) { + int const offset_Ddesc = cute::is_void_v ? 
0 : sm_count; + tma_desc = &gmem_tensormap[sm_idx + offset_Ddesc]; + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_D), Int<1>{}, Int<1>{}); + copy(recast(pD_tensormap), recast(sD_tensormap)); + } + __syncwarp(); + } + + return tma_desc; + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& params, + int32_t next_batch) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_C, + params.ptr_C[next_batch]); + } + } + } else if constexpr (is_destination_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_D, + params.ptr_D[next_batch]); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape = {1,1,1,1,1}; + cute::array prob_stride = {0,0,0,0,0}; + + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + 
cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } + } + } + else if constexpr (is_destination_supported) { + ElementD const* ptr_D = nullptr; + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D, + prob_shape, + prob_stride); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& params, + cute::TmaDescriptor const* tensormap, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormap, params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release(shared_tensormap, tensormap); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormap, + 
cute::TmaDescriptor const* tensormap) { + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. + auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if (is_source_supported) { + tma_desc_wait_all_fn(); + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_C); + } + } else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); + tma_descriptor_cp_fence_release(tensormap, shared_tensormap.smem_tensormap_D); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (IsLoad) { + if (is_source_supported) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } else if constexpr (is_destination_supported) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90dfb80c00b7c4c48ce74d69cca52aeea8b80baa --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ 
-0,0 +1,856 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/conv/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +template +struct IsDefaultFusionOp { + static constexpr bool value = false; +}; + +template< + class ElementD, class ElementCompute, + class ElementC, FloatRoundStyle RoundStyle +> +struct IsDefaultFusionOp< + epilogue::fusion::LinearCombination< + ElementD, ElementCompute, ElementC, ElementCompute, RoundStyle> +> { + static constexpr bool value = true; +}; + +template< + class ElementOutput, int Count, class ElementAccumulator, + class ElementCompute, epilogue::thread::ScaleType::Kind Scale, + FloatRoundStyle Round, class ElementSource +> +struct IsDefaultFusionOp< + epilogue::thread::LinearCombination< + ElementOutput, Count, ElementAccumulator, + ElementCompute, Scale, Round, ElementSource> +> { + static constexpr bool value = true; +}; + +// Legacy direct store sm100 epilogue using thread::LinearCombination, do not expect this to be stable +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100NoSmem; + using EpilogueTile = EpilogueTile_; + // derived 
types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpT2R = CopyOpT2R_; + using AlignmentC = AlignmentC_; + using AlignmentD = AlignmentD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + constexpr static int ThreadCount = 128; + constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; + constexpr static bool isEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + constexpr static bool isSourceNeeded = not cute::is_void_v; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter 
= nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); + if constexpr (isSourceNeeded) { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + return implementable; + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params, SharedStorage&) : params(params) { }; + +protected: + Params const& params; + + // + // Non-static Device Methods + // +public: + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&) { + + using namespace cute; + using X = Underscore; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + auto cta_coord_mnl = 
select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; + constexpr int VD = cute::min(AlignmentD{}, size(mclD)); + Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); + Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); + + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); + Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tDpD); t++) { + tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); + } + + constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), 
tTR_gC.layout())){}; + constexpr int VC = cute::min(AlignmentC{}, size(mclC)); + + Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); + Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); + Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tCpC); t++) { + tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); + } + Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); + Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); + + // Detect interleaved complex fp32 kernels + [[maybe_unused]] Tensor accs = accumulators; + using ElementTmem = typename decltype(accs)::value_type; + constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; + + // 1. Load accumulators into register from tmem + // Tmem -> rmem and transformation for interleaved complex kernels + if constexpr (is_interleaved_complex_f32) { + using ElementComputeAccumulator = float; + + Tensor tAccReal = accumulators(make_coord(_,_),_0{},_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAccImag = accumulators(make_coord(_,_),_0{},_0{},_1{}); // (CTA_M,CTA_N) + Tensor tTR_tAccReal = thread_t2r.partition_S(tAccReal); // (T2R,T2R_M,T2R_N) + Tensor tTR_tAccImag = thread_t2r.partition_S(tAccImag); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccReal = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAccImag = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAccReal, tTR_rAccReal); + copy(tiled_t2r, tTR_tAccImag, tTR_rAccImag); + + // 1.1. 
Transform accumulators in registers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAccReal); i++) { + tTR_rAcc(i) = {tTR_rAccReal(i), tTR_rAccImag(i)}; + } + } + + // Standard tmem -> rmem epilogue + else { + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); // (T2R,T2R_M,T2R_N) + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + } + + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + + // 2. Apply element-wise operation and store to gmem + ThreadEpilogueOp epilogue_op{params.thread}; + // source is needed + if (epilogue_op.is_source_needed()) { + copy_if(tCpC, tTR_rC_src, tTR_rC_dst); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); + } + + return cute::make_tuple(acc_pipe_consumer_state); + } + + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. 
+ template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledCopy + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK cta_tile_shape_mnk, + TileCoordMNKL cta_coord_mnkl, + cute::Tensor& tTR_rGlobAcc, // (MMA,MMA_M,MMA_N) + [[maybe_unused]] SharedStorage&, + TiledCopy tiled_t2r) { + + using namespace cute; + using X = Underscore; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_shape_mnk); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + + + // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) + + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + // 2. 
Apply element-wise operation and store to gmem + ThreadEpilogueOp epilogue_op{params.thread}; + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cCD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i), tTR_gC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rGlobAcc); ++i) { + if (elem_less(tTR_cCD(i), problem_shape_mnl)) { + tTR_gD(i) = epilogue_op(tTR_rGlobAcc(i)); + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Direct store sm100 epilogue supporting EVT +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + // Required by the gemm::kernel + using DispatchPolicy = Sm100NoSmem; + using ElementC = ElementC_; + using ElementD = ElementD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using StrideC = StrideC_; + using StrideD = StrideD_; + using EpilogueTile = EpilogueTile_; + using CopyOpT2R = CopyOpT2R_; + using FusionCallbacks = FusionCallbacks_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + +private: + constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported + || is_same_v; // alloc reduction buffer for custom EVTs + constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? 
size(EpilogueTile{}) : 0; + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + +public: + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + array_aligned buffer; + }; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC = {}; + ElementD* ptr_D = nullptr; + StrideD dD = {}; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC = {}; + ElementD* ptr_D = nullptr; + StrideD dD = {}; + }; + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + : fusion_callbacks(params_.thread, shared_tensors.thread) + , smem_buffer_ptr(shared_tensors.buffer.data()) + , params(params_) {}; + +protected: + FusionCallbacks fusion_callbacks; + uint8_t* smem_buffer_ptr; + Params const& params; + +public: + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + 
CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return fusion_implementable; + } + + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + [[maybe_unused]] SharedStorage& + ) { + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + // Wait for mma warp to fill tmem buffer with accumulator results + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + auto [M, N, K, L] = problem_shape_mnkl; + auto problem_shape_mnl = select<0,1,3>(problem_shape_mnkl); + auto cta_coord_mnl = select<0,1,3>(cta_coord_mnkl); + auto cta_tiler = take<0,2>(cta_tile_mnk); + + int thread_idx = threadIdx.x % ThreadCount; + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); 
// (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); + + // Construct the EVT consumer callbacks + auto residue_cD = make_coord(M,N) - cD(_0{}); + auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); + Tensor cD_ = make_coord_tensor(cD.layout()); + Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); + constexpr bool RefSrc = false; + + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); + + Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mC, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + + Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D)), make_shape(M,N,L), params.dD); + + Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + + // Register Tensor + Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); + + Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); + Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + coord_cCD, cta_tile_mnk, cta_coord_mnkl, EpilogueTile{}, tiled_t2r, thread_idx); + constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; + constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); + + 
auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); + constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; + constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); + + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); + + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + int(0), + EpilogueTile{}, + tiled_t2r, + cD_, + residue_cD, + tTR_cD_, + residue_tTR_cD, + tCrC, + thread_idx + }; + + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. + synchronize(); + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. 
+ [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + + bool is_last_iteration = iter_m == size<3>(tTR_tAcc)-1 && iter_n == size<4>(tTR_tAcc)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<4>(tTR_tAcc) - 1 - iter_n; + } + do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0; + } + + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); + cst_callbacks.begin_loop(epi_m, epi_n); + + if constexpr (not cute::is_void_v) { + if (is_C_load_needed) { + using CVecType = uint_bit_t>; + + if constexpr (!is_same_v) { + Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + Tensor tTR_rC_frg = recast(coalesce(tCrC)); + Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); + copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } + } + } + + // Copy 
accumulator tile from tmem to register + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); + + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + Tensor reduction_buffer = make_tensor( + raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); + + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); + + cst_callbacks.end_loop(epi_m, epi_n); + + using VecType = uint_bit_t>; + if constexpr (!is_same_v) { + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); + copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + + } // for epi_m + } // for epi_n + + cst_callbacks.end(); + }; + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); + return cute::make_tuple(acc_pipe_consumer_state); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// For sm100 kernels requiring warp specialized epilogues +template < + 
class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class ThreadEpilogueOp_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100NoSmemWarpSpecialized, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + ThreadEpilogueOp_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_ +> : public detail::Sm100TmaWarpSpecializedAdapter> +{ +public: + // ctor inheritance + using detail::Sm100TmaWarpSpecializedAdapter>::Sm100TmaWarpSpecializedAdapter; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..412a4b7b747b60ebedfa26eec95a692b4d9adaf4 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm100_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,1299 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/trace.h" + +#include "cutlass/conv/detail.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileShape_, // (CTA_M,CTA_N,CTA_K, optional: Tile_L) + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm100TmaWarpSpecialized, + CtaTileShape_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm100TmaWarpSpecialized; + using CtaTileShape = CtaTileShape_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using 
CopyOpT2R = CopyOpT2R_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + constexpr static int ThreadCount = 128; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + +private: + using GmemElementD = ElementD; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); + + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool is_source_supported = not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + using SmemLayoutStageC = decltype(tile_to_shape(SmemLayoutAtomC{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayoutStageD = decltype(tile_to_shape(SmemLayoutAtomD{}, product_each(shape(EpilogueTile{})), + cute::conditional_t, Step<_1,_2>>{} )); + + constexpr static int StageCBits = cosize_v * sizeof_bits_v; + constexpr static int StageDBits = cosize_v * sizeof_bits_v; + constexpr static int 
MaxStageBits = cute::max(StageCBits, StageDBits); + constexpr static int StrideStageC = (ReuseSmemC ? MaxStageBits : StageCBits) / sizeof_bits_v; + constexpr static int StrideStageD = (ReuseSmemC ? MaxStageBits : StageDBits) / sizeof_bits_v; + + using SmemLayoutC = decltype(cute::append<3>(SmemLayoutStageC{}, Layout, Int>{})); + using SmemLayoutD = decltype(cute::append<3>(SmemLayoutStageD{}, Layout, Int>{})); + + constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + && MaxStageBits % sizeof_bits_v == 0 + && MaxStageBits % sizeof_bits_v == 0; + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. 
+ constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + // TMA store delay only benefits with loop unrolling + constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = StageCBits / 8; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Planar complex kernels have two accumulator copies for the real and imaginary tensors. 
+ constexpr static int NumAccumulatorMtxs = 1; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + +private: + static constexpr auto + get_tma_epi_tile() { + return cute::transform_apply(EpilogueTile{}, seq<0,1>{}, + [] (auto epi_tiler, auto mode) { + auto cta_tiler_shape = get(CtaTileShape{}); + // Use a dynamic stride to prevent mode coalescing + auto cta_tiler_stride = repeat_like(cta_tiler_shape, 0); + auto cta_tiler = make_layout(cta_tiler_shape, cta_tiler_stride); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return shape(composition(cta_tiler, epi_tiler)); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }, + [] (auto... 
epi_tilers) { + return make_tile(epi_tilers...); + } + ); + } + + using TmaEpilogueTile = decltype(get_tma_epi_tile()); + + template + static constexpr auto + get_tma_load_c(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), + make_layout(problem_shape_mnl, append<3>(args.dC, _0{}))); + return make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutStageC{}, TmaEpilogueTile{}, _1{}); + } + + template + static constexpr auto + get_tma_store_d(ProblemShapeMNL const& problem_shape_mnl, Arguments const& args) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), + make_layout(problem_shape_mnl, append<3>(args.dD, _0{}))); + return make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutStageD{}, TmaEpilogueTile{}, _1{}); + } + +public: + // Device side epilogue params + struct Params { + using TMA_C = decltype(get_tma_load_c (repeat_like(append<3>(StrideC{},_1{}), int32_t(0)), Arguments{})); + using TMA_D = decltype(get_tma_store_d(repeat_like(append<3>(StrideD{},_1{}), int32_t(0)), Arguments{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + }; + + // + // Gemm Host Functions + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnl = select<0,1,3>(append<4>(problem_shape, 1)); + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + tma_load_c = get_tma_load_c(problem_shape_mnl, args); + } + + typename Params::TMA_D tma_store_d = get_tma_store_d(problem_shape_mnl, args); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + 
return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_d = cutlass::detail::get_output_alignment_bits(); + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_d / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideD{}); + } + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_c = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_c / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the 
minimum requirements for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + // + // Conv Host Functions + // + + template + static constexpr Params + to_underlying_arguments(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return to_underlying_arguments(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace); + } + + template + static size_t + get_workspace_size(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return get_workspace_size(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + + template + static cutlass::Status + initialize_workspace(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args, + void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return initialize_workspace(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement(cutlass::conv::ConvProblemShape const& problem_shape, Arguments const& args) { + return can_implement(cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape), args); + } + + // + // Static Device Functions + // + + template + CUTLASS_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(cta_tile_mnk)), EpilogueTile{})); + } + + template + CUTLASS_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK const& cta_tile_mnk) { + return get_load_pipe_increment(cta_tile_mnk); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + 
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + + // + // Non-static Device Functions + // +public: + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + bool ReuseTmem = false, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + TensorStorage& shared_tensors, + bool reverse_epi_n = false) { + using namespace cute; + + int lane_idx = canonical_lane_idx(); + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(cta_tile_mnk)); + Tensor gC = local_tile(mC, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (TMA,TMA_M,TMA_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + lane_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gC_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gC_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gC_epi) - 1 - iter_n; + } + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; 
+ uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE void + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] StorePipeline store_pipeline, + [[maybe_unused]] StorePipelineState store_pipe_producer_state) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor accumulators, + TensorStorage& shared_tensors + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename 
epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(accumulators) == 3, "Accumulators must be MMA-partitioned: [MMA, MMA_M, MMA_N]"); + static_assert(size<1>(accumulators) == 1 && size<2>(accumulators) == 1, "TiledMMA must match partitioned ShapeMN"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + // Does directly store the visitor into smem. 
+ constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tTR_rAcc = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); // (EPI_V) + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + CUTE_STATIC_ASSERT(size(tTR_rAcc) % DispatchPolicy::FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + Tensor tRR_rD_src = thread_r2r.retile_S(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tTR_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + ThrCopy thread_r2s = 
tiled_r2s.get_slice(thread_idx); + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + Tensor tRS_rD = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (!IsDirectR2S) { + return make_tensor(shape(tRS_sD(_,_,_,_0{}))); + } + else{ + return thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + } + }(); + + Tensor tRR_rD_dst_frg = recast>(coalesce(tRR_rD_dst)); + Tensor tRS_rD_frg = recast>(coalesce(tRS_rD)); + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Arguments for the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + // Thread synchronizer for previously 
issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for sub-128 thread T2R tiled copy + Layout tmem_warp_layout = typename decltype(make_tmem_warp_partitioner(tAcc_epi(_,_,0,0)))::TiledLayout_TV{}; + constexpr bool predicate_tmem_load = size(tmem_warp_layout) != cosize(tmem_warp_layout); + bool issue_tmem_load = true; + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. 
+ // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // The TMA store sequence for one epilogue loop iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr 
(ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; // tma_store_fn + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + // Begin the wait for the accumulator results + ConsumerToken acc_wait_token = acc_pipeline.consumer_try_wait(acc_pipe_consumer_state); + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? 
NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + bool do_acc_release = is_last_iteration; + + // Reverse subtile order for tmem reuse if necessary + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = size<3>(gD_epi) - 1 - iter_n; + } + do_acc_release = iter_m == size<2>(gD_epi)-1 && iter_n == 0; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if (is_first_iteration) { + // Wait for mma warp to fill tmem buffer with accumulator results + acc_pipeline.consumer_wait(acc_pipe_consumer_state, acc_wait_token); + } + + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + // Compute tmem load predication if necessary + if constexpr (predicate_tmem_load) { + // Issue tmem load if this tile's tmem subpartition is accessible by this warp + int subpart_idx = (tTR_tAcc_mn.data().dp_ / 32) % 4; + issue_tmem_load = warp_idx == subpart_idx; + } 
+ bool issue_smem_store = issue_tmem_load; + + // Copy accumulator tile from tmem to register + if (issue_tmem_load) { + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + } + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + if constexpr (!IsDirectR2S) { + // At present, only FP4 col output with scalefactor generation fusion would go into these branch + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + tRS_rD_frg(_0{}) = cutlass::NumericArrayConverter{}(tRR_rD_dst_frg(_0{})); + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy output tile from register to smem + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not 
DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + }; // epi_loop_fn + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); + } + + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. 
+ template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + TensorStorage& shared_tensors, + TiledCopyT2R tiled_t2r + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = + conditional_return(make_coord(m_coord, n_coord), make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to 
eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_coord_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = 
cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. 
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + 
++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<3>(gD_epi)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<2>(gD_epi)); + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesN : 1) + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + #pragma unroll(UnrollEpiLoop ? NumEpiSubtilesM : 1) + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + Tensor tTR_rAcc_epi_tile = tTR_rAcc(_,_,_,epi_m,epi_n); + Tensor 
tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); + + // Copy output tile from register to smem + bool issue_smem_store = true; + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, 
store_pipe_producer_state); + } + + template + CUTLASS_DEVICE void + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + CtaTileMNK cta_tile_mnk) { + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(cta_tile_mnk)); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2b8d84dc92fb8b1a823135b2fdc556bce9dbebc --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -0,0 +1,549 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class StrideC, + class StrideD, + class ThreadEpilogueOp, + class SmemLayout, + class CopyAtomR2S, + class TiledCopyS2R, + class CopyAtomR2G, + class EpilogueScheduleType = EpilogueSimtVectorized, + class Enable = void +> +class Epilogue { + static_assert(cute::is_same_v || + cute::is_same_v, + "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename 
ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = CopyAtomR2G; + + static constexpr bool IsEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using StrideBias = cute::conditional_t(), Stride<_1,_0,int64_t>, Stride<_0,_1,int64_t>>; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + static constexpr bool IsActHasArgs = detail::IsThreadEpilogueOpWithElementwiseArguments::value; + + // Host side epilogue arguments + template + struct ThreadEpilogueOpArguments { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + }; + + template + struct ThreadEpilogueOpArguments< + ThreadEpiOp, + cute::enable_if_t::value>> { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + }; + + struct Arguments { + ThreadEpilogueOpArguments thread{}; + using StrideBias = decltype(thread.dBias); + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + template + struct ParamsType { + typename ThreadEpiOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; 
+ + template + struct ParamsType< + ThreadEpiOp, + cute::enable_if_t::value>> { + typename ThreadEpiOp::Params thread{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + using Params = ParamsType; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + typename ThreadEpilogueOp::Params thread_op_args; + thread_op_args.alpha = args.thread.alpha; + thread_op_args.beta = args.thread.beta; + thread_op_args.alpha_ptr = args.thread.alpha_ptr; + thread_op_args.beta_ptr = args.thread.beta_ptr; + + if constexpr (IsActHasArgs) { + return { + thread_op_args, + args.thread.activation, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + else { + return { + thread_op_args, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class 
FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), params.dBias); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = 
blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gBiast = flat_divide(gBias, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBiast); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rC = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor 
tSR_rBias = make_tensor_like(tSR_gBias); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("gBias : "); print(gBias.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rC : "); print(tSR_rC.layout()); print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + print("gBiast : "); print(gBiast.layout()); print("\n"); + print("tSR_gBias : "); print(tSR_gBias.layout()); print("\n"); + print("tSR_rBias : "); print(tSR_rBias.layout()); print("\n"); + } +#endif + + if constexpr (IsEpilogueBiasSupported) { + if (params.ptr_Bias) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tSR_gBias_flt = 
filter_zeros(tSR_gBias); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); + Tensor tSR_pD_flt = cute::lazy::transform(tSR_cD_flt, [&](auto const& c){ return elem_less(c, take<0,2>(residue_mnk)); }); + + // Step 0. Copy Bias from GMEM to fragment + copy_if(tSR_pD_flt, tSR_gBias_flt, tSR_rBias_flt); + } + } + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if constexpr (IsEpilogueBiasSupported) { + Tensor tSR_rBiasmn = tSR_rBias(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. 
Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i)); + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i)); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } else { + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i), tSR_rC(i)); + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. 
Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5030efded1e3608d91d0dca87f9f41fff827875f --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Ptr Array Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. 
+/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = TiledCopyS2R; + using GmemTiledCopyD = TiledCopyS2R; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; 
+ StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. 
+ ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = params.dC[l_coord]; + } + stride_d = params.dD[l_coord]; + } + else { + stride_c = params.dC; + stride_d = params.dD; + } + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); 
// (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); 
print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + Tensor tSR_rCmn = make_tensor(shape(tSR_gCmn)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rCmn(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. 
Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rD(i,m,n) = epilogue_op(tSR_rAcc(i,m,n), tSR_rCmn(i,m,n)); + } + // Step 7. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..77ef3ed2defbc2f286ac3002185a2864a8b322f8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -0,0 +1,1245 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm90PtrArrayTmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = 
StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and 
is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> 
smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; + constexpr static uint32_t MinTensorMapWorkspaceAlignment = 64; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_C; + cute::array smem_tensormap_D; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC; + ElementD ** ptr_D = nullptr; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + 
repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(CtaTileMNK{})); + auto init_N = int32_t(size<1>(CtaTileMNK{})); + auto init_L = 1; + + static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + stride_c = args.dC; + stride_d = args.dD; + } + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. 
+ ElementC const* ptr_C_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_C) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{}); + } + + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + // NOTE: Since TMA desc creation with nullptr not possible until 12.6, we use an initial address even when tensor addresses are on device. This address is never used. + ElementD const* ptr_D_first_batch = reinterpret_cast(reinterpret_cast(args.ptr_D) & 0xFFFFFFFFFFFFFFF0); // Address must be 16B-aligned + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{}); + } + + auto fusion_workspace = static_cast(workspace); + auto fusion_workspace_size = round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment); + auto tma_descriptor_workspace = reinterpret_cast( + static_cast(workspace) + fusion_workspace_size); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), + tma_load_c, + tma_store_d, + tma_descriptor_workspace, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + transaction_bytes, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); + auto descriptors_shape = cute::make_shape(sm_count, Int{}); + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (size(descriptors_shape) * SizeOfCuTensorMap) + + (round_nearest(FusionCallbacks::get_workspace_size(problem_shape, args.thread), MinTensorMapWorkspaceAlignment)); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool implementable = true; + bool fusion_implementable = true; + + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" 
CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if (cute::is_void_v || args.ptr_C == nullptr) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + CUTLASS_DEVICE auto + load_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + // Initialize tma for loading + constexpr bool IsLoad = true; + auto 
load_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0); + return load_tensormaps; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class TensorMapC, + __CUTE_REQUIRES(std::is_pointer_v) + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapC const& load_tensormap, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + 
problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + LoadPipelineState last_load_producer_state = load_pipe_producer_state; + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + LoadPipelineState prior_state = load_pipe_producer_state; + + bool did_load = false; + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (is_C_load_needed) { + if (issue_tma_load) { + copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + last_load_producer_state = load_pipe_producer_state; + did_load = true; + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + 
LoadPipelineState load_pipe_producer_state) { + + if (!fusion_callbacks.is_producer_load_needed()) { + return load_pipe_producer_state; + } + + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapD const& store_tensormap, + int subtile_idx=-1) { + + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA 
is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // 
Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // 
(R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor 
tRS_rCompute_frg = recast>(tRS_rCompute); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (a single thread from one warp issues TMA store) + bool issue_tma_store = ((thread_idx / NumThreadsPerWarp) == 0) && cute::elect_one_sync(); + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + 
copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + } 
+ } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( + tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); + } + } + } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the 
next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output needs register shuffling before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, 
+ StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; + // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps + if (warp_idx_in_warp_group == 0) { + // Initialize tma + constexpr bool IsLoad = false; + auto store_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx); + return store_tensormaps; + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); + Layout desc_layout = make_layout(make_shape(sm_count, Int{})); + + Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) + + if constexpr (IsLoad) { + if (is_source_supported) { + constexpr int C_tensormap_index = NumEpilogueWarpGroups; + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + copy(recast(pC_tensormap), recast(sC_tensormap)); + } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); + + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + else { + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + copy(recast(pD_tensormap), recast(sD_tensormap)); + } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); + } + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_batch, + int32_t warp_group_idx) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, + params.ptr_C[next_batch]); + } + } + } + else if constexpr (is_destination_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], 
+ params.ptr_D[next_batch]); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl, + int32_t warp_group_idx) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape = {1,1,1,1,1}; + cute::array prob_stride = {0,0,0,0,0}; + + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } + } + } + else if constexpr (is_destination_supported) { + ElementD const* ptr_D = nullptr; + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + prob_shape, + prob_stride); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& params, + cute::TmaDescriptor const* tensormap, + ProblemShape_MNKL 
problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); + } + + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormaps, + cute::TmaDescriptor const* tensormap, + const int32_t warp_group_idx = 0) { + // Commit and wait for all TMA load/store instructions before updating the tensormap in gmem. + // This operation only happens when the group/batch changes between consecutive tiles. + // If there are no uncommitted instructions then tma_desc_commit_group results in an empty bulk async-group. + auto tma_desc_wait_all_fn = [] () CUTLASS_LAMBDA_FUNC_INLINE { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + }; + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + tma_desc_wait_all_fn(); + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); + } + } + else if constexpr (is_destination_supported) { + tma_desc_wait_all_fn(); + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + else { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..062b9a8b582a1a3c05407f163a0ca4b05646028a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,958 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool 
ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + 
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. 
+ constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + // TMA pipeline for storing D + using StorePipeline = 
cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = 
make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); + } + + typename Params::TMA_D tma_store_d{}; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d, + transaction_bytes + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = cutlass::detail::check_alignment(shape, StrideD{}); + } + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int 
min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + 
} + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
+ auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + ); + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = 
load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF 
resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); 
+ } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && 
decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. + TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); + // Relative coordinate tensors (static) + Tensor cD = make_coord_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_coord_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple 
MMA tiles + CUTE_STATIC_ASSERT(epi_tile_n % mma_tile_n == 0, "MMA_TILE_N must divide EPI_TILE_N"); + } + else { + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + } + + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + ); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. 
+ // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) CUTLASS_LAMBDA_FUNC_INLINE { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + 
store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + 
load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + if constexpr (epi_tile_m * epi_tile_n > mma_tile_m * mma_tile_n) { + // When the epilogue subtile is larger than the MMA tiles, loop over multiple + // MMA tiles + static constexpr int MmaMPerEpiM = epi_tile_m / mma_tile_m; + static constexpr int MmaNPerEpiN = epi_tile_n / mma_tile_n; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_in_epi = 0; mma_n_in_epi < MmaNPerEpiN; ++mma_n_in_epi) { + int mma_n = (epi_n * MmaNPerEpiN) + mma_n_in_epi; + + CUTLASS_PRAGMA_UNROLL + for (int mma_m_in_epi = 0; mma_m_in_epi < MmaMPerEpiM; ++mma_m_in_epi) { + int mma_m = (epi_m * MmaMPerEpiM) + mma_m_in_epi; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + int idx_in_epi_subtile = (mma_n_in_epi * MmaMPerEpiM + mma_m_in_epi); + + tRS_rCompute_frg(idx_in_epi_subtile) = cst_callbacks.visit( + tRS_rAcc_frg_mn(0), idx_in_epi_subtile, epi_m, epi_n); + } + } + } + else { + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for 
workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if 
(fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2d5fd85827b2751085a78dcb241aa3cf081470d5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. + This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. 
+*/ + +#pragma once + +#include "sm90_epilogue_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) + class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class Sm90EpilogueTmaWarpSpecializedBiasElementwise + : public CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +private: + using Impl = + CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ + >; +public: + using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; + using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; + using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; + using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; + + // Constructor inheritance + using Impl::Impl; + + // Host side epilogue arguments + struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] + Arguments { + struct ThreadArgs { + ElementCompute alpha{1}; + ElementCompute beta{0}; + ElementCompute const 
*alpha_ptr{nullptr}; + ElementCompute const *beta_ptr{nullptr}; + } thread; + ElementC_ const* ptr_C{nullptr}; + StrideC_ dC{}; + ElementD_* ptr_D{nullptr}; + StrideD_ dD{}; + ElementBias const* ptr_Bias{nullptr}; + ElementT* ptr_T{nullptr}; + + CUTLASS_HOST_DEVICE + operator typename Impl::Arguments() const { + typename Impl::Arguments arguments; + arguments.thread.alpha = thread.alpha; + arguments.thread.beta = thread.beta; + arguments.thread.alpha_ptr = thread.alpha_ptr; + arguments.thread.beta_ptr = thread.beta_ptr; + if constexpr (not cute::is_void_v) { + arguments.thread.bias_ptr = ptr_Bias; + } + if constexpr (not cute::is_void_v) { + arguments.thread.aux_ptr = ptr_T; + arguments.thread.dAux = dD; + } + arguments.ptr_C = ptr_C; + arguments.dC = dC; + arguments.ptr_D = ptr_D; + arguments.dD = dD; + + return arguments; + } + }; + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ca91ac19b0aadfeddcfb030ee16f03905855cd63 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/dispatch_policy.hpp @@ -0,0 +1,302 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue { + +////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////// +// +// Builder Epilogue Schedules +// +////////////////////////////////////////////////////////////////////////////// +// Pre-Hopper schedules +struct PtrArrayDefault {}; +struct EpilogueSimtVectorized {}; +struct EpiloguePtrArraySimtVectorized {}; +// Hopper direct store schedules +struct NoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecializedTransposed {}; +// Hopper TMA schedules +struct TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperative {}; +struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; }; +struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; }; +struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; +// Blackwell direct store schedules +struct NoSmemWarpSpecialized1Sm {}; +struct NoSmemWarpSpecialized2Sm {}; +struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct BlockwiseNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct BlockwiseNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm 
{}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayBlockwiseNoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; +// Blackwell TMA schedules +struct TmaWarpSpecialized1Sm {}; +struct TmaWarpSpecialized2Sm {}; +struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; +struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; +struct TmaWarpSpecialized1SmNvf4 final : TmaWarpSpecialized1Sm {}; +struct TmaWarpSpecialized2SmNvf4 final : TmaWarpSpecialized2Sm {}; +struct TmaWarpSpecialized1SmMxf4 final : TmaWarpSpecialized1Sm {}; +struct TmaWarpSpecialized2SmMxf4 final : TmaWarpSpecialized2Sm {}; +struct TmaWarpSpecialized1SmMxf8f6f4 final : TmaWarpSpecialized1Sm {}; +struct TmaWarpSpecialized2SmMxf8f6f4 final : TmaWarpSpecialized2Sm {}; +// Cooperative epilogue schedule for sm120 sparse kernels +struct SparseTmaWarpSpecializedCooperativeSm120 : public TmaWarpSpecializedCooperative {}; + +// DEPRECATED schedules, will be removed in next release +struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] 
+TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; +struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +////////////////////////////////////////////////////////////////////////////// +// +// Collective Dispatch Policies +// +////////////////////////////////////////////////////////////////////////////// + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int 
FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ +> +struct Sm90PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; +}; + +// DEPRECATED policies, will be removed in next release +template< + int StagesC_, + int StagesD_, + int FragmentSize_ = 2 +> +struct Sm90TmaWarpSpecializedBiasElementwise { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; +}; + + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm100PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + static_assert(StagesC >= 1, "StagesC must be >= 1"); + static_assert(StagesD >= 1, "StagesD must be >= 1"); +}; + +struct Sm100NoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int 
FragmentSize = 1; +}; + +struct Sm100NoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100PtrArrayNoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100PtrArrayNoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm120TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ +> +struct Sm120PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9febeec4d92d54ec02e221d028f7329c2edeea5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/callbacks.hpp @@ -0,0 
+1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. +template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) +> +struct FusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct FusionCallbacksTraits { + using DispatchPolicy = void; + using Callbacks = T; + using Operation = FusionOperation; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; + using ElementCompute = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... 
Args +> +struct FusionCallbacksTraits< + FusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Callbacks = FusionCallbacks; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..114737a9d910a458f4895212d0904e002a9aeec8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/operations.hpp @@ -0,0 +1,645 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include // cute::false_type + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct FusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; + + using ElementSource = void; + static constexpr bool IsSourceSupported = false; 
+ static constexpr bool IsResidualSupported = false; // Source is added after activation + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + + using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; + + using ElementBlockScaleFactor = void; + static constexpr int SFVecSize = 0; + static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues + using GmemLayoutTagScalefactor = void; +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAcc : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinearCombination + : ScaledAcc { + using ElementSource 
= ElementSource_; + static constexpr bool IsSourceSupported = true; +}; + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombTopKSoftmaxCol + : LinearCombination { +}; + + +// D = alpha * acc + beta * C + per-row bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +// D = alpha * acc + beta * C + per-column bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) 
+template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// Grouped Wgrad's D = alpha * acc + beta * C with special AccFetch. +template< + class GroupsPerTile_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinearCombinationGroupedWgrad + : LinearCombination { + using GroupsPerTile = GroupsPerTile_; +}; + +// D = activation(alpha * acc + beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltAct + : LinCombPerColBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// aux = alpha * acc + beta * C + per-row bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> 
+struct LinCombPerRowBiasEltActAux + : LinCombPerRowBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-col bias) +// aux = alpha * acc + beta * C + per-col bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActAux + : LinCombPerColBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerRowLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerRowScaleSupported = true; +}; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class 
ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerColScaleSupported = true; +}; + +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColResAddPerColBiasEltAct + : PerColLinCombPerColBiasEltAct { + static constexpr bool IsResidualSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = 
ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltActAmaxAux + : ScaledLinCombPerRowBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, 
+ class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltActAmaxAux + : ScaledLinCombPerColBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsDeEltActSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxInSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +// dBias = sum of columns of D +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltActDePerRowBias + : LinCombDeEltAct { + using ElementBias = ElementBias_; + static constexpr int 
AlignmentBias = AlignmentBias_; + static constexpr bool IsDePerRowBiasSupported = true; +}; + +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombBlockScaleFactor + : LinearCombination { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + +// D = activation(alpha * acc + beta * C) +// With BlockScaleFactor generation (same recipe as LinCombBlockScaleFactor). +template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltActBlockScaleFactor + : LinCombEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + +// D = alpha * acc + beta * C + per-row bias +// With BlockScaleFactor generation +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct 
LinCombPerRowBiasBlockScaleFactor + : LinCombPerRowBias { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = alpha * acc + beta * C + per-col bias +// With BlockScaleFactor generation. +template< + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasBlockScaleFactor + : LinCombPerColBias { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = activation(alpha * acc + beta * C + per-row bias) +// With BlockScaleFactor generation. 
+template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltActBlockScaleFactor + : LinCombPerRowBiasEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +// D = activation(alpha * acc + beta * C + per-col bias) +// With BlockScaleFactor generation. +template< + template class ActivationFn_, + int SFVecSize_, + class ElementOutput_, + class ElementCompute_, + class ElementBlockScaleFactor_, + class GmemLayoutTagScalefactor_ = cutlass::layout::RowMajor, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActBlockScaleFactor + : LinCombPerColBiasEltAct { + using ElementBlockScaleFactor = ElementBlockScaleFactor_; + static constexpr int SFVecSize = SFVecSize_; + static constexpr bool IsBlockScaleSupported = true; + using GmemLayoutTagScalefactor = GmemLayoutTagScalefactor_; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dfbb75bf00bd2160af770566c4f3970a2c7b5b10 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,1322 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the sm100 TMA warp-specialized (ws) epilogue +*/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#include "cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Sm100 Tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... 
+ > { + using FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// Sm100 direct store callbacks alias to sm100 tma callbacks with 0 stages +// Additional copy atom args will be ignored in the 0-stage specializations of aux load/store nodes +template < + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100NoSmemWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized<0, 0, 0, false, false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// Sm100 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +// With Row BlockScaleFactor Generation. 
+template< + int SFVecsize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinearCombRowBlockScaleFactor = + Sm90EVT, // gen scalefactor + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinearCombRowBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = alpha * acc + beta * C +// With Col BlockScaleFactor Generation. +template< + int SFVecsize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinearCombColBlockScaleFactor = + Sm90EVT, // gen scalefactor + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = 
Sm100LinearCombColBlockScaleFactor::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For Ptr-Array and Grouped GEMM +// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinearCombRowBlockScaleFactorPtrArray = + Sm90EVT, // gen scalefactor + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinearCombRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ // NormConst is a single device-side constant value, its not per-batch or per-group + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// For Ptr-Array and Grouped GEMM +// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombEltActRowBlockScaleFactorPtrArray = + Sm90EVT, // gen scalefactor + Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized, + fusion::LinCombEltActBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm100LinCombEltActRowBlockScaleFactorPtrArray::type, ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltActBlockScaleFactor; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +// with row blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerRowBiasRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementCompute, 
ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, + ElementScalar, + AlignmentBias, + RoundStyle + > +{ + + using Impl = + Sm100LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// D = alpha * acc + beta * C + per-row bias +// with col blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerRowBiasColBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorColStore< + SFVecsize, EpilogueTile, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + 
+template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasColBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm100LinCombPerRowBiasColBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per_col bias +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerColBiasRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, + ElementBias, ElementSource, 
ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm100LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// with row blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerRowBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, + ElementOutput, ElementCompute, + ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, + ElementCompute, ElementCompute, 
ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm100LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // 
Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// with col blockScaled generation +template< + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using 
Sm100LinCombPerRowBiasEltActColBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorColStore< + SFVecsize, EpilogueTile, + ElementOutput, ElementCompute, + ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerRowBiasEltActColBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm100LinCombPerRowBiasEltActColBlockScaleFactor< + SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { 
+ ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per_col bias) +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class 
ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm100LinCombPerColBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm100BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, + ElementOutput, ElementCompute, + ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm100LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm100LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, CtaTileShapeMNK, EpilogueTile, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, 
ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using 
Impl::Impl; +}; + + +// -------------------------------------------------------------------- +// Sm100PtrArrayNoSmemWarpSpecialized (direct-store, grouped GEMM) +// -------------------------------------------------------------------- +template < + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayNoSmemWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> + : FusionCallbacks< + // reuse the ptr-array *TMA* callbacks with 0 stages + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> { + + using Base = FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>; + + // bring ctors into scope + using Base::Base; +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a20591288ad386543c3c7f0fd399c7fe45b7f60a --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree compute operations for the sm100 TMA warp-specialized (ws) epilogue +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// BatchNormApply +// +// This node aims to do the batch norm apply. The procedure is described as follows: +// +// output = (input - mean) * inv_stddev * alpha + bias +// +// while: (1) input & output are 2 matrices with shape (M, N), +// which are frg_input & return value of the visit function +// +// (2) mean, inv_stddev, alpha & bias are 4 vectors with shape (N). 
+// which are loaded by ProducerLoadCallbacks +// +// To avoid redundant calculations in EVT, this node simplify the procedure as follows: +// +// output = input * alpha' + bias' +// +// while alpha' & bias' are 2 vectors with shape (N) calculated by mean, inv_stddev, alpha & bias +// +// The calculation among vectors is described as follows: +// +// alpha' = alpha * inv_stddev +// bias' = bias - mean * alpha' +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least + // this should just match CLC stage count + int Stages, + class CtaTileShapeMNK, + class ElementScalar, + class ElementCompute, + class ElementOutput, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm100BatchNormApply { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(cute::is_same_v>); // row vector broadcast for alpha, bias, mean & inv_stddev + + using SmemLayout = decltype(make_layout(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})))); + + using ElementCol = cute::conditional_t<(sizeof(ElementCompute) > sizeof(ElementScalar)), ElementCompute, ElementScalar>; + + struct SharedStorage { + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_alpha; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_bias; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_mean; + alignas(16) array_aligned(CtaTileShapeMNK{}) * Stages> smem_inv_stddev; + }; + + struct Arguments { + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* bias_ptr = nullptr; + ElementScalar const* mean_ptr = nullptr; + ElementScalar const* inv_stddev_ptr = nullptr; + StrideMNL dVec = {}; + }; + + struct Params { 
+ using TMA_Vec = decltype(make_tma_atom( + SM90_TMA_LOAD{}, + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), + take<0,2>(SmemLayout{}), + take<0,2>(CtaTileShapeMNK{}))); + + TMA_Vec tma_load_alpha; + TMA_Vec tma_load_bias; + TMA_Vec tma_load_mean; + TMA_Vec tma_load_inv_stddev; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + Tensor tensor_alpha = make_tensor(make_gmem_ptr(args.alpha_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_bias = make_tensor(make_gmem_ptr(args.bias_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_mean = make_tensor(make_gmem_ptr(args.mean_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + Tensor tensor_inv_stddev = make_tensor(make_gmem_ptr(args.inv_stddev_ptr), make_layout(make_shape(size(M),N,size(L)), append<3>(args.dVec, _0{}))); + + typename Params::TMA_Vec tma_load_alpha = make_tma_atom(SM90_TMA_LOAD{}, tensor_alpha, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename Params::TMA_Vec tma_load_bias = make_tma_atom(SM90_TMA_LOAD{}, tensor_bias, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename Params::TMA_Vec tma_load_mean = make_tma_atom(SM90_TMA_LOAD{}, tensor_mean, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + typename Params::TMA_Vec tma_load_inv_stddev = make_tma_atom(SM90_TMA_LOAD{}, tensor_inv_stddev, take<0,2>(SmemLayout{}), take<0,2>(CtaTileShapeMNK{})); + + return Params{tma_load_alpha, tma_load_bias, tma_load_mean, tma_load_inv_stddev}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, 
Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm100BatchNormApply() { } + + CUTLASS_HOST_DEVICE + Sm100BatchNormApply(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_alpha(const_cast(shared_storage.smem_alpha.data())), + smem_bias(const_cast(shared_storage.smem_bias.data())), + smem_mean(const_cast(shared_storage.smem_mean.data())), + smem_inv_stddev(const_cast(shared_storage.smem_inv_stddev.data())), + smem_col_alpha(const_cast(shared_storage.smem_alpha.data())), + smem_col_bias(const_cast(shared_storage.smem_bias.data())) { } + + Params const* params_ptr; + ElementScalar* smem_alpha; + ElementScalar* smem_bias; + ElementScalar* smem_mean; + ElementScalar* smem_inv_stddev; + ElementCompute* smem_col_alpha; + ElementCompute* smem_col_bias; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& gAlpha, GTensor&& gBias, GTensor&& gMean, GTensor&& gInvStddev, + STensor&& sAlpha, STensor&& sBias, STensor&& sMean, STensor&& sInvStddev, Params const* params_ptr) + : gAlpha(cute::forward(gAlpha)), + gBias(cute::forward(gBias)), + gMean(cute::forward(gMean)), + gInvStddev(cute::forward(gInvStddev)), + sAlpha(cute::forward(sAlpha)), + sBias(cute::forward(sBias)), + sMean(cute::forward(sMean)), + sInvStddev(cute::forward(sInvStddev)), + params_ptr(params_ptr) {} + + GTensor gAlpha; + GTensor gBias; + GTensor gMean; + GTensor 
gInvStddev; + + STensor sAlpha; + STensor sBias; + STensor sMean; + STensor sInvStddev; + + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if (epi_m == 0 && epi_n == 0 && issue_tma_load) { + // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size + constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * bits_to_bytes(sizeof_bits_v) * 4; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA bulk copy + int pipe_index = (load_iteration / EpiTiles) % Stages; + copy(params_ptr->tma_load_alpha.with(*full_mbarrier_ptr), gAlpha, sAlpha(_,pipe_index)); + copy(params_ptr->tma_load_bias.with(*full_mbarrier_ptr), gBias, sBias(_,pipe_index)); + copy(params_ptr->tma_load_mean.with(*full_mbarrier_ptr), gMean, sMean(_,pipe_index)); + copy(params_ptr->tma_load_inv_stddev.with(*full_mbarrier_ptr), gInvStddev, sInvStddev(_,pipe_index)); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mAlpha = params_ptr->tma_load_alpha.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mBias = params_ptr->tma_load_bias.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mMean = params_ptr->tma_load_mean.get_tma_tensor(make_shape(size(M),N,size(L))); + Tensor mInvStddev = params_ptr->tma_load_inv_stddev.get_tma_tensor(make_shape(size(M),N,size(L))); + + Tensor gAlpha = local_tile(mAlpha, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gBias = local_tile(mBias, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gMean = local_tile(mMean, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gInvStddev = local_tile(mInvStddev, 
take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sBias = make_tensor(make_smem_ptr(smem_bias), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sMean = make_tensor(make_smem_ptr(smem_mean), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), SmemLayout{}); // (CTA_M,CTA_N,PIPE) + + auto [tCgAlpha, tCsAlpha] = tma_partition(params_ptr->tma_load_alpha, group_modes<0,2>(sAlpha), group_modes<0,2>(gAlpha)); + auto [tCgBias, tCsBias] = tma_partition(params_ptr->tma_load_bias, group_modes<0,2>(sBias), group_modes<0,2>(gBias)); + auto [tCgMean, tCsMean] = tma_partition(params_ptr->tma_load_mean, group_modes<0,2>(sMean), group_modes<0,2>(gMean)); + auto [tCgInvStddev, tCsInvStddev] = tma_partition(params_ptr->tma_load_inv_stddev, group_modes<0,2>(sInvStddev), group_modes<0,2>(gInvStddev)); + + constexpr int EpiTiles = decltype(size(ceil_div(shape(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ProducerLoadCallbacks( + cute::move(tCgAlpha), cute::move(tCgBias), cute::move(tCgMean), cute::move(tCgInvStddev), + cute::move(tCsAlpha), cute::move(tCsBias), cute::move(tCsMean), cute::move(tCsInvStddev), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + SR_RTensor&& tSR_rAlpha, SR_RTensor&& tSR_rBias, + SR_RTensor&& tSR_rMean, SR_RTensor&& tSR_rInvStddev, + SR_STensor&& tSR_sAlpha, SR_STensor&& tSR_sBias, + SR_STensor&& tSR_sMean, SR_STensor&& tSR_sInvStddev, + SR_CTensor&& tSR_cAlpha, + SR_SCTensor&& tSR_sColAlpha, SR_SCTensor&& tSR_sColBias, + RTensor&& tCrAlpha, RTensor&& tCrBias, + STensor&& tCsAlpha, STensor&& tCsBias, + ThrNum thr_num, + Params const* params_ptr) + : + tSR_rAlpha(cute::forward(tSR_rAlpha)), tSR_rBias(cute::forward(tSR_rBias)), + tSR_rMean(cute::forward(tSR_rMean)), 
tSR_rInvStddev(cute::forward(tSR_rInvStddev)), + tSR_sAlpha(cute::forward(tSR_sAlpha)), tSR_sBias(cute::forward(tSR_sBias)), + tSR_sMean(cute::forward(tSR_sMean)), tSR_sInvStddev(cute::forward(tSR_sInvStddev)), + tSR_cAlpha(cute::forward(tSR_cAlpha)), + tSR_sColAlpha(cute::forward(tSR_sColAlpha)), tSR_sColBias(cute::forward(tSR_sColBias)), + tCrAlpha(cute::forward(tCrAlpha)), tCrBias(cute::forward(tCrBias)), + tCsAlpha(cute::forward(tCsAlpha)), tCsBias(cute::forward(tCsBias)), + thr_num(thr_num), + params_ptr(params_ptr) {} + + SR_RTensor tSR_rAlpha; + SR_RTensor tSR_rBias; + SR_RTensor tSR_rMean; + SR_RTensor tSR_rInvStddev; + SR_STensor tSR_sAlpha; + SR_STensor tSR_sBias; + SR_STensor tSR_sMean; + SR_STensor tSR_sInvStddev; + SR_CTensor tSR_cAlpha; + SR_SCTensor tSR_sColAlpha; + SR_SCTensor tSR_sColBias; + + ThrNum thr_num; + + RTensor tCrAlpha; // (CPY,CPY_M,CPY_N) + RTensor tCrBias; // (CPY,CPY_M,CPY_N) + + STensor tCsAlpha; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + STensor tCsBias; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + + Params const* params_ptr; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if (epi_m == 0 && epi_n == 0) { // Assumes M-major subtile loop + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + int pipe_index = (load_iteration / EpiTiles) % Stages; + + Tensor tSR_rAlpha_flt = filter_zeros(tSR_rAlpha); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_rMean_flt = filter_zeros(tSR_rMean); + Tensor tSR_rInvStddev_flt = filter_zeros(tSR_rInvStddev); + Tensor tSR_sAlpha_flt = filter_zeros(tSR_sAlpha(_,_,_,pipe_index)); + Tensor tSR_sBias_flt = filter_zeros(tSR_sBias(_,_,_,pipe_index)); + Tensor tSR_sMean_flt = filter_zeros(tSR_sMean(_,_,_,pipe_index)); + 
Tensor tSR_sInvStddev_flt = filter_zeros(tSR_sInvStddev(_,_,_,pipe_index)); + Tensor tSR_cAlpha_flt = filter_zeros(tSR_cAlpha, tSR_rAlpha.stride()); + + for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { + if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + // OOB of SMEM + continue; + } + tSR_rAlpha_flt(i) = tSR_sAlpha_flt(i); + tSR_rBias_flt(i) = tSR_sBias_flt(i); + tSR_rMean_flt(i) = tSR_sMean_flt(i); + tSR_rInvStddev_flt(i) = tSR_sInvStddev_flt(i); + } + + constexpr int RegFragSize = cute::min(size(tSR_rAlpha_flt), cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute)))); + Tensor tSR_rAlpha_frg = recast>(tSR_rAlpha_flt); // (FRG_V) + Tensor tSR_rBias_frg = recast>(tSR_rBias_flt); // (FRG_V) + Tensor tSR_rMean_frg = recast>(tSR_rMean_flt); // (FRG_V) + Tensor tSR_rInvStddev_frg = recast>(tSR_rInvStddev_flt); // (FRG_V) + + cutlass::multiplies> mul; + cutlass::negate> negate; + cutlass::multiply_add> mul_add; + + // We do computation among vectors before computation among matrices + // alpha' = alpha * inv_stddev + // bias' = bias - alpha' * mean + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAlpha_frg); ++i) { + tSR_rAlpha_frg(i) = mul(tSR_rAlpha_frg(i), tSR_rInvStddev_frg(i)); + tSR_rBias_frg(i) = mul_add(tSR_rAlpha_frg(i), negate(tSR_rMean_frg(i)), tSR_rBias_frg(i)); + } + + Tensor tSR_sColAlpha_flt = filter_zeros(tSR_sColAlpha(_,_,_,pipe_index)); + Tensor tSR_sColBias_flt = filter_zeros(tSR_sColBias(_,_,_,pipe_index)); + // After computation, 4 vectors -> 2 vectors + for (int i = 0; i < size(tSR_rAlpha_flt); ++i) { + if (get<1>(tSR_cAlpha_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + // OOB of SMEM + continue; + } + tSR_sColAlpha_flt(i) = tSR_rAlpha_flt(i); + tSR_sColBias_flt(i) = tSR_rBias_flt(i); + } + + synchronize(); + + // To do bn_apply with Acc, reload these 2 vectors with the consistent shape + copy_aligned(tCsAlpha(_,_,_,_,_,pipe_index), tCrAlpha); + copy_aligned(tCsBias(_,_,_,_,_,pipe_index), tCrBias); + } + } + + 
template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_inputs) { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + cutlass::multiply_add> mul_add; + + Array frg_apply; + + using ConvertInput = NumericArrayConverter; + using ConvertOutput = NumericArrayConverter; + + ConvertInput convert_input{}; + ConvertOutput convert_output{}; + + Array frg_I = convert_input(frg_inputs); + + Tensor tCrAlpha_frg = recast>(tCrAlpha(_,_,_,epi_m,epi_n)); + Tensor tCrBias_frg = recast>(tCrBias(_,_,_,epi_m,epi_n)); + + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + using RegFragArr = Array, RegFragArraySize>; + RegFragArr& frg_I_ = reinterpret_cast(frg_I); + RegFragArr& frg_apply_ = reinterpret_cast(frg_apply); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + frg_apply_[i] = mul_add(tCrAlpha_frg(epi_v * RegFragArraySize + i), frg_I_[i], tCrBias_frg(epi_v * RegFragArraySize + i)); + } + + return convert_output(frg_apply); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor sAlpha = make_tensor(make_smem_ptr(smem_alpha), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sBias = make_tensor(make_smem_ptr(smem_bias), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sColAlpha = make_tensor(make_smem_ptr(smem_col_alpha), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sColBias = make_tensor(make_smem_ptr(smem_col_bias), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sMean = make_tensor(make_smem_ptr(smem_mean), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor sInvStddev = make_tensor(make_smem_ptr(smem_inv_stddev), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + + // S2R: Smem to Reg + auto tiled_s2r = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_s2r = tiled_s2r.get_slice(args.thread_idx); + Tensor tSR_sAlpha = thr_s2r.partition_S(sAlpha); + Tensor tSR_sBias = thr_s2r.partition_S(sBias); + Tensor tSR_sMean = thr_s2r.partition_S(sMean); + Tensor tSR_sInvStddev = thr_s2r.partition_S(sInvStddev); + Tensor tSR_sColAlpha = thr_s2r.partition_S(sColAlpha); + Tensor tSR_sColBias = thr_s2r.partition_S(sColBias); + Tensor tSR_cAlpha = 
thr_s2r.partition_S(args.cD); + + Tensor tSR_rAlpha = make_tensor_like(take<0,3>(tSR_sAlpha)); // need to check + Tensor tSR_rBias = make_tensor_like(take<0,3>(tSR_sBias)); + Tensor tSR_rMean = make_tensor_like(take<0,3>(tSR_sMean)); + Tensor tSR_rInvStddev = make_tensor_like(take<0,3>(tSR_sInvStddev)); + + Tensor tCsAlpha = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sColAlpha, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCsBias = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sColBias, args.epi_tile, args.tiled_copy, args.thread_idx); + + Tensor tCrAlpha = make_tensor_like(take<0,5>(tCsAlpha)); // (CPY,CPY_M,CPY_N) + Tensor tCrBias = make_tensor_like(take<0,5>(tCsBias)); // (CPY,CPY_M,CPY_N) + + constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value; + return ConsumerStoreCallbacks( + cute::move(tSR_rAlpha), cute::move(tSR_rBias), + cute::move(tSR_rMean), cute::move(tSR_rInvStddev), + cute::move(tSR_sAlpha), cute::move(tSR_sBias), + cute::move(tSR_sMean), cute::move(tSR_sInvStddev), + cute::move(tSR_cAlpha), + cute::move(tSR_sColAlpha), cute::move(tSR_sColBias), + cute::move(tCrAlpha), cute::move(tCrBias), + cute::move(tCsAlpha), cute::move(tCsBias), + ThreadCount{}, + params_ptr); + } +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d026b15ccacef0bb199b7a98172c722f9402d075 --- /dev/null +++ 
b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm100_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm100 TMA warp-specialized (ws) epilogue +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/helper_macros.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +namespace detail { + template + CUTLASS_DEVICE auto + compute_quantized_with_row_scalefactor( + Array& frg_compute, + Array& frg_sf, + ElementCompute norm_constant) + { + cutlass::multiplies mul; + cutlass::multiplies> mul_array; + + Array frg_output; + auto output_frgs = reinterpret_cast *>(frg_output.data()); + auto compute_frgs = reinterpret_cast *>(frg_compute.data()); + + Array qpvscale_rcps = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (cute::is_same_v) { + // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate>{}(frg_sf); + return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcp); + } + else { + // UE4M3: Do the rcp in fp32 data type. + auto qpvscale_ups = cutlass::NumericArrayConverter{}(frg_sf); + return cutlass::reciprocal_approximate_ftz{}(qpvscale_ups); + } + }(); + + // norm_constant and qpvscale_rcps are all positive numbers. 
+ auto acc_scales = cutlass::multiplies>{}(norm_constant, qpvscale_rcps); + + CUTLASS_PRAGMA_UNROLL + for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { + // Map INF to fp32::max + auto acc_scale = minimum_with_nan_propagation{}(acc_scales[sf_v], cutlass::platform::numeric_limits::max()); + // Convert to output type + output_frgs[sf_v] = cutlass::NumericArrayConverter{}(mul_array(compute_frgs[sf_v], acc_scale)); + } + return frg_output; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// BlockScaleFactor Generation Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int SFVecSize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm100BlockScaleFactorRowStore { + static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); + static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or + size<1>(EpilogueTile{}) / SFVecSize == 2 or + size<1>(EpilogueTile{}) / SFVecSize == 4 or + size<1>(EpilogueTile{}) / SFVecSize == 8, + "Possible store in interleaved 4B aligned format"); + using NormalConstStrideMNL = Stride<_0,_0,int64_t>; + struct SharedStorage { }; + + struct Arguments { + ElementBlockScaleFactor* ptr_scale_factor = nullptr; + ElementCompute const* norm_constant_ptr = nullptr; + NormalConstStrideMNL norm_constant_stride = {}; + }; + + using Params = Arguments; + + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, 
Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = (N % SFVecSize == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); + } + return implementable; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorRowStore() { } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class GTensor, + class CoordGTensor, + class ThrResidue, + class EpiTileCoordMN, + class ElementType + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) + GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + CoordGTensor tC_cSFD_, // (m,n) + ThrResidue residue_tC_cSFD_, // (m,n) + Params const* params_ptr_, + EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) + ElementType norm_constant_, + ElementType norm_constant_scaled_down_) + : tC_rSFD(cute::forward(tC_rSFD_)) + , tC_gSFD(cute::forward(tC_gSFD_)) + , tC_cSFD(tC_cSFD_) + , residue_tC_cSFD(residue_tC_cSFD_) + , 
params_ptr(params_ptr_) + , norm_constant(norm_constant_) + , norm_constant_scaled_down(norm_constant_scaled_down_) + , epi_tile_coord_mn(epi_tile_coord_mn_){} + + static_assert(is_same_v); + RTensor tC_rSFD; + GTensor tC_gSFD; + CoordGTensor tC_cSFD; + ThrResidue residue_tC_cSFD; + Params const* params_ptr; + ElementCompute norm_constant; + ElementCompute norm_constant_scaled_down; + EpiTileCoordMN epi_tile_coord_mn; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, + int epi_m, + int epi_n, + Array const& frg_input) + { + static_assert(FragmentSize % SFVecSize == 0, "Scale factor vector size should divide FragmentSize"); + constexpr int NumVecs = FragmentSize / SFVecSize; + Array frg_compute; + + auto input_frgs = reinterpret_cast const*>(frg_input.data()); + auto compute_frgs = reinterpret_cast *>(frg_compute.data()); + + Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) + + cutlass::multiplies mul; + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // SF generation + CUTLASS_PRAGMA_UNROLL + for (int sf_v = 0; sf_v < NumVecs; ++sf_v) { + compute_frgs[sf_v] = NumericArrayConverter{}(input_frgs[sf_v]); + /// Step1: get max across a vector + vec_maxs[sf_v] = amax_reduction(ElementCompute(0), compute_frgs[sf_v]); + } + + /// Step2: Compute Scale + pvscales = cutlass::multiplies>{}(vec_maxs, norm_constant_scaled_down); + + tC_rSFD_frg(_0{}) = cutlass::NumericArrayConverter{}(pvscales); + + Tensor tCgSFD_flt = filter_zeros(tC_gSFD(_,_,_,_0{},_0{},get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n)); + Tensor tCrSFD_flt = filter_zeros(tC_rSFD); + constexpr auto MCL = decltype(max_common_layout(tCgSFD_flt, tCrSFD_flt)){}; + constexpr int V = cute::min(4, size(MCL)); + using VecType = uint_bit_t>; + Tensor tCgSFD_vec = recast(coalesce(tCgSFD_flt)); + Tensor tCrSFD_vec = recast(coalesce(tCrSFD_flt)); + Tensor tCcSFD_pred = 
tC_cSFD(_,_,_, epi_m, epi_n); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSFD_vec); i++){ + if (elem_less(tCcSFD_pred(i * SFVecSize * V), residue_tC_cSFD)) { + tCgSFD_vec(i) = tCrSFD_vec(i); + } + } + /// Step3: Compute quantized output values + return detail::compute_quantized_with_row_scalefactor(frg_compute, tC_rSFD_frg(_0{}), norm_constant); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; + tile_coord_l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } + + auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); + Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) + Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + + auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); + 
+ // Fetch and compute these during initialization + Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); + ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); + ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); + ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); + ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); +#if 0 + if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ + print("epi_tile ");print(args.epi_tile); print("\n"); + print("mSFD ");print(mSFD); print("\n"); + print("gSFD ");print(gSFD); print("\n"); + print("tCgSFD ");print(tCgSFD); print("\n"); + print("tCrSFD ");print(tCrSFD); print("\n"); + print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); + print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); + } +#endif + + return ConsumerStoreCallbacks( + cute::move(tCrSFD), + cute::move(tCgSFD), + args.tCcD, + args.residue_tCcD, + params_ptr, + epi_tile_coord_mn, + norm_constant, + norm_constant_scaled_down); + + } +}; + +template < + int SFVecSize, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm100BlockScaleFactorColStore { + + static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); + static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or + size<0>(EpilogueTile{}) / SFVecSize == 2 or + size<0>(EpilogueTile{}) / SFVecSize == 4 or + size<0>(EpilogueTile{}) / SFVecSize == 8, + "Possible store in interleaved 4B aligned format"); + using NormalConstStrideMNL = Stride<_0,_0,int64_t>; + static constexpr int NumSyncWarps = SFVecSize == 64 ? 
4 : 0; + static constexpr int NumSyncThreads = NumSyncWarps * NumThreadsPerWarp; + struct SharedStorage { + array_aligned smem_aux; + }; + + struct Arguments { + ElementBlockScaleFactor* ptr_scale_factor = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + ElementCompute const* norm_constant_ptr = nullptr; + NormalConstStrideMNL norm_constant_stride = {}; + }; + + using Params = Arguments; + + // BlockScaleFactor generation is per batch or group + // For Ptr-Array GEMM and Grouped GEMM, ElementBlockScaleFactor is ElementType* + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = (M % SFVecSize == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm100BlockScaleFactorColStore] M-dim should be divisible by SFVecSize.\n"); + } + return implementable; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorColStore() { } + + CUTLASS_HOST_DEVICE + Sm100BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) + , smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr = nullptr; + ElementCompute *smem_aux = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() 
const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class GTensor, + class STensor, + class CoordGTensor, + class ThrResidue, + class EpiTileCoordMN, + class ElementType + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + // Normally, we should use tile_shape_mnk to tile the gtensor. + // However, the SF gtensor could not be divisible by non-pow2 cta tile, so we use epi tile (pow2) to do tiling. + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rSFD_, // (CPY,CPY_M,CPY_N) + GTensor&& tC_gSFD_, // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + STensor&& sAmaxs_, // (NumSyncWarps) + CoordGTensor tC_cSFD_, // (m,n) + ThrResidue residue_tC_cSFD_, // (m,n) + Params const* params_ptr_, + EpiTileCoordMN epi_tile_coord_mn_, // (epi_tile_coord_m, epi_tile_coord_n) + ElementType norm_constant_, + ElementType norm_constant_scaled_down_) + : tC_rSFD(cute::forward(tC_rSFD_)) + , tC_gSFD(cute::forward(tC_gSFD_)) + , sAmaxs(cute::forward(sAmaxs_)) + , tC_cSFD(tC_cSFD_) + , residue_tC_cSFD(residue_tC_cSFD_) + , params_ptr(params_ptr_) + , norm_constant(norm_constant_) + , norm_constant_scaled_down(norm_constant_scaled_down_) + , epi_tile_coord_mn(epi_tile_coord_mn_) {} + + static_assert(is_same_v); + RTensor tC_rSFD; + GTensor tC_gSFD; + STensor sAmaxs; + CoordGTensor tC_cSFD; + ThrResidue residue_tC_cSFD; + Params const* params_ptr; + ElementCompute norm_constant; + ElementCompute norm_constant_scaled_down; + EpiTileCoordMN epi_tile_coord_mn; + + CUTLASS_DEVICE + ElementCompute find_amax(ElementCompute max) { + // Overall idea: after TMEM_LOAD.32DP32bit pattern, each thread in the warp can load adjacent elements of a column into its private RF. 
+ // Here we are using shuffle instructons to the amax value of the adjacent column elements. + // For VS16, t0~t15 would generate an amax, and t16~t31 would generate another one. + // For VS32, t0~t31 should generate an amax. + // For VS64, t0~t63 should generate an amax. We would first do the reduciton within a warp, + // and then use smem to do inter-warp reduction. + if constexpr (SFVecSize == 32) { + return cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); + } + else if constexpr (SFVecSize == 16) { + return cutlass::redux_abs_max_nan_propagation_sync_warp_t0t15_t16t31{}(max); + } + else if constexpr (SFVecSize == 64) { + // Get abs_max per warp + auto abs_max = cutlass::redux_abs_max_nan_propagation_sync_warp{}(max); + + // Switch the amax of adjacent warps + const bool leading_thread = (threadIdx.x % NumThreadsPerWarp) == 0; + const int warp_idx = threadIdx.x / NumThreadsPerWarp % 4; + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(NumSyncThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + // Inter-warp reduction for VS=64 + // Only 4 * FP32 = 16 bytes smem is needed as we have 4 warps. + if (leading_thread) { + sAmaxs(warp_idx) = abs_max; + } + synchronize(); + // Switch data between two adjacent warps to do reduction + float tmp = sAmaxs(warp_idx^1); + synchronize(); + abs_max = cutlass::maximum_with_nan_propagation{}(abs_max,tmp); + return abs_max; + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported VecSize"); + } + } + + template + CUTLASS_DEVICE auto + compute_quantized_value(Array compute, Array sf) { + cutlass::multiplies> mul_array; + auto qpvscale_rcp = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (cute::is_same_v) { + // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. 
+ auto e8m0_qpvscale_rcps = cutlass::reciprocal_approximate>{}(sf); + return cutlass::NumericArrayConverter{}(e8m0_qpvscale_rcps); + } + else { + // UE4M3: Do the rcp in fp32 data type. + auto qpvscale_up = cutlass::NumericArrayConverter{}(sf); + return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); + } + }(); + // norm_constant and qpvscale_rcps[sf_v] are all positive numbers. + auto acc_scale = mul_array(norm_constant, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + return mul_array(compute, acc_scale); + } + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, + int epi_m, + int epi_n, + Array const& frg_input) + { + constexpr int NumVecs = 1; // each thread only compute 1 col scalefactors + Array frg_compute; + Array frg_output; + Array frg_scale_float; + Array frg_amax; + Array frg_scale; + + Tensor tC_rSFD_frg = recast>(coalesce(filter(tC_rSFD))); // (EPI_V) + + cutlass::multiplies mul; + cutlass::multiplies> mul_array; + /// convert acc to Element Compute + auto compute_frgs = NumericArrayConverter{}(frg_input); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + /// Step1: get max across a vector + frg_amax[i] = find_amax(compute_frgs[i]); + } + + frg_scale_float = mul_array(frg_amax, norm_constant_scaled_down); + frg_scale = cutlass::NumericArrayConverter{}(frg_scale_float); + auto tC_cSFD_pred = tC_cSFD(_,_,_,epi_m,epi_n); + auto tC_gSFD_store = tC_gSFD(_,_,_,_,_,get<0>(epi_tile_coord_mn) + epi_m, get<1>(epi_tile_coord_mn) + epi_n); + for (int i=0; i < cute::ceil_div(FragmentSize, SFVecSize); i++) { + int idx = i * SFVecSize + threadIdx.x % SFVecSize; + if (idx < FragmentSize && elem_less(tC_cSFD_pred(idx), residue_tC_cSFD)) { + UnderlyingElementBlockScaleFactor tmp = frg_scale[idx]; + // Store the (EpilogueTile / SFVecSize) elements. 
+ tC_gSFD_store(idx) = tmp; + } + } + + /// Step3: Compute quantized output values + if constexpr (cute::sizeof_bits_v == 4) { + return compute_quantized_value(compute_frgs, frg_scale); // ElementCompute + } + else { + // 6bits or 8bits output. + compute_frgs = compute_quantized_value(compute_frgs, frg_scale); + frg_output = cutlass::NumericArrayConverter{}(compute_frgs); + return frg_output; // ElementOutput + } + + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [tile_coord_m, tile_coord_n, tile_coord_k, tile_coord_l] = args.tile_coord_mnkl; + using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[tile_coord_l]; + tile_coord_l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } + + auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + //Tensor gSFD = local_tile(mSFD, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); + // Normally, we should use tile_shape_mnk to tile the mSFD tensor. However, we could not do it for non-pow2 cta tile with vectorsize = 32. + // For scale factor, 128x4 elements are stored in a basic block, and the layout of mSFD is ((_32,_4,int),(_32,_4,int),int):((_16,_4,int),(_0,_1, int),int) + // If we tiled it using tile_shape_mnk(128, 192), the N mode would encounter shape_div failure because (32, 4) could not be divisible by 192. 
+ // Therefore, switching to using pow2 epilogue tile. + static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); + Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_,_,tile_coord_l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) + Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + + auto epi_tile_coord_mn = make_coord(tile_coord_m * size<0>(epi_tile_mn), tile_coord_n * size<1>(epi_tile_mn)); + + // Fetch and compute these during initialization + Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); + ElementCompute norm_constant = mNormConst(_0{},_0{},tile_coord_l); + ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); + ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); + ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); + + Tensor sAmaxs = make_tensor(make_smem_ptr(smem_aux), make_layout(_4{})); +#if 0 + if(threadIdx.x == 128 && blockIdx.x == 0 && blockIdx.y == 0){ + print("mSFD ");print(mSFD); print("\n"); + print("gSFD ");print(gSFD); print("\n"); + print("tCgSFD ");print(tCgSFD); print("\n"); + print("tCrSFD ");print(tCrSFD); print("\n"); + print("args.tCcD ");print(args.tCcD); print("\n"); + print("args.residue_tCcD ");print(args.residue_tCcD); print("\n"); + print("filter(tCrSFD) ");print(filter(tCrSFD)); print("\n"); + print("filter(tCgSFD) ");print(filter(tCgSFD)); print("\n"); + } +#endif + + return ConsumerStoreCallbacks( + cute::move(tCrSFD), + cute::move(tCgSFD), + cute::move(sAmaxs), + args.tCcD, + args.residue_tCcD, + params_ptr, + epi_tile_coord_mn, + norm_constant, + 
norm_constant_scaled_down); + } +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b769b1f0fbe2aa78f0ee97da442fb61c1aa49cc8 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,1593 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief Fusion callbacks specializations for the SM120 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Sm120 Tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... 
+ > { + using FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// D = alpha * acc + beta * C +// With BlockScaleFactor Generation. +// 1. Find max of 32 F32 elements +// 2. Convert the max to UE8 (or UE4M3) and store the result. +// 3. Convert the UE8 (or UE4M3) back to F32 scale. +// 4. Reciprocal of F32 scale with MUFU. +// 5. Multiply each F32 element with the above reciprocal, then convert to ElementD +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinearCombRowBlockScaleFactor = + Sm90EVT, // gen scalefactor + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm120LinearCombRowBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + + using Sm100Fusion = FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile + >; + using Operation = typename Sm100Fusion::Operation; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + 
ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = alpha * acc + beta * C + per-row bias +// with row blockScaled generation +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerRowBiasRowBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, // gen scalefactor + Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, 
RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm120LinCombPerRowBiasRowBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// with row blockScaled generation +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerRowBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, // gen scalefactor + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, 
ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm120LinCombPerRowBiasEltActRowBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids 
generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = alpha * acc + beta * C + per_col bias +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerColBiasRowBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, 
FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, // gen scalefactor + Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm120LinCombPerColBiasRowBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* 
beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = activation(alpha * acc + beta * C + per_col bias) +// with row blockScaled generation +template< + int StagesC, + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerColBiasEltActRowBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + 
ElementCompute, ElementBlockScaleFactor, RoundStyle + >, // gen scalefactor + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm120LinCombPerColBiasEltActRowBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + 
ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +// with per column blockScaled generation +// 1. Find max of 32 F32 elements +// 2. Convert the max to UE8 (or UE4M3) and store the result. +// 3. Convert the UE8 (or UE4M3) back to F32 scale. +// 4. Reciprocal of F32 scale with MUFU. +// 5. 
Multiply each F32 element with the above reciprocal, then convert to ElementD +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinearCombColBlockScaleFactor = Sm90EVT< + Sm120BlockScaleFactorColStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle>, + Sm90LinearCombination< + ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle> + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized< + StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>, + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute,ElementBlockScaleFactor, + cutlass::layout::ColumnMajor, ElementSource, ElementScalar, RoundStyle>, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinearCombColBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = Sm120LinearCombColBlockScaleFactor::type,ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle>; + + using Sm100Fusion = FusionCallbacks< + epilogue::Sm100TmaWarpSpecialized, + fusion::LinCombBlockScaleFactor, + CtaTileShapeMNK, + EpilogueTile + >; + using Operation = typename Sm100Fusion::Operation; + + struct 
Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { + { + // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = alpha * acc + beta * C + per-Col bias +// with per column blockScaled generation +template< + int StagesC, + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerColBiasColBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorColStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBias< + StagesC, 
CtaTileShapeMNK, EpilogueTile, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerColBiasColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm120LinCombPerColBiasColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + 
// Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = activation(alpha * acc + beta * C + per_col bias) +// with per column blockScaled generation +template< + int StagesC, + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerColBiasEltActColBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorColStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, + 
ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerColBiasEltActColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + using Impl = + Sm120LinCombPerColBiasEltActColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerColBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, + ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * 
block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// with per column blockScaled generation +template< + int StagesC, + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = 
FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerRowBiasEltActColBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorColStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, + ElementCompute, ElementCompute, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerRowBiasEltActColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > { + + + using Impl = + Sm120LinCombPerRowBiasEltActColBlockScaleFactor< + StagesC, SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias,ElementSource, ElementScalar, + AlignmentBias, RoundStyle + >; + + using Operation = + fusion::LinCombPerRowBiasEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, 
cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +// D = alpha * acc + beta * C + per-row bias +// with per column blockScaled generation +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + 
class ElementCompute, + class ElementBlockScaleFactor, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombPerRowBiasColBlockScaleFactor = + Sm90EVT< + Sm120BlockScaleFactorColStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor, RoundStyle + >, // gen scalefactor + Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementCompute, ElementCompute, + ElementBias, ElementSource, ElementScalar, + AlignmentBias, RoundStyle + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120TmaWarpSpecialized, + fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombPerRowBiasColBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + > +{ + + using Impl = + Sm120LinCombPerRowBiasColBlockScaleFactor< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementBias, + ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + using Operation = + 
fusion::LinCombPerRowBiasBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::ColumnMajor, + ElementBias, ElementSource, ElementScalar,AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBlockScaleFactor * block_scale_factor_ptr = nullptr; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +// Sm120 Ptr array tma warp specialized callbacks just alias to their sm90 counterpart +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... 
Args +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... +> : FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args... + > { + using FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>::FusionCallbacks; +}; + +// For Ptr-Array and Grouped GEMM +// D = alpha * acc + beta * C, where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. +template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinearCombRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinearCombinationPtrArray< ElementCompute, ElementCompute, + ElementSource, ElementScalar, RoundStyle + > // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : 
Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinearCombRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombBlockScaleFactor< + SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + + operator typename Impl::Arguments() const { + return + { + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + +// For Ptr-Array and Grouped GEMM +// D = activation(alpha * acc + beta * C), where alpha and beta can be vectors for each batch/group +// With Row BlockScaleFactor Generation, separate tensors per batch/group. 
+template< + int SFVecsize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm120LinCombEltActRowBlockScaleFactorPtrArray = + Sm90EVT< + Sm120BlockScaleFactorRowStore< + SFVecsize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ElementOutput, + ElementCompute, ElementBlockScaleFactor *, RoundStyle + >, // gen scalefactor + Sm90LinCombEltActPtrArray // activation(beta * C + (alpha * acc)) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + int SFVecSize, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm120PtrArrayTmaWarpSpecialized, + fusion::LinCombEltActBlockScaleFactor< + ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + > { + + using Impl = + Sm120LinCombEltActRowBlockScaleFactorPtrArray< + SFVecSize, EpilogueTile, CtaTileShapeMNK, FragmentSize, ActivationFn, + typename cutlass::detail::get_unpacked_element_type::type, + ElementCompute, ElementBlockScaleFactor, ElementSource, ElementScalar, RoundStyle + >; + + using Operation = + fusion::LinCombEltActBlockScaleFactor< + 
ActivationFn, SFVecSize, ElementOutput, ElementCompute, + ElementBlockScaleFactor, cutlass::layout::RowMajor, + ElementSource, ElementScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + ElementBlockScaleFactor ** block_scale_factor_ptr = nullptr; + + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + using StrideNormConst = Stride<_0,_0,int64_t>; + ElementCompute const* norm_constant_ptr = nullptr; + StrideNormConst dNormConst = {_0{}, _0{}, 0}; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {block_scale_factor_ptr, norm_constant_ptr, dNormConst} // BlockScaleFactor args + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e72e971bd8d99f87a2528af3c1dbd27366298ef5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm120_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,899 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief Visitor tree store operations for the SM120 TMA warp-specialized (ws) epilogue +*/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// BlockScaleFactor Generation Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int SFVecSize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm120BlockScaleFactorRowStore { + + static_assert(size<1>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); + static_assert(size<1>(EpilogueTile{}) / SFVecSize == 1 or + size<1>(EpilogueTile{}) / SFVecSize 
== 2 or + size<1>(EpilogueTile{}) / SFVecSize == 4 or + size<1>(EpilogueTile{}) / SFVecSize == 8, + "Possible store in interleaved 4B aligned format"); + + static constexpr int NumWarpgroups = 2; + static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; + static constexpr int NumQuadsPerWarp = 8; + static constexpr int NumSyncQuads = NumSyncWarps * NumQuadsPerWarp; + struct SharedStorage { + array_aligned smem_aux; + }; + using NormalConstStrideMNL = Stride<_0,_0,int64_t>; + struct Arguments { + ElementBlockScaleFactor* ptr_scale_factor = {}; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. + ElementCompute const* norm_constant_ptr = {}; + NormalConstStrideMNL norm_constant_stride = {}; + }; + + using Params = Arguments; + + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = (N % SFVecSize == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorRowStore] N-dim should be divisible by SFVecSize.\n"); + } + return implementable; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm120BlockScaleFactorRowStore() { } + + CUTLASS_HOST_DEVICE + Sm120BlockScaleFactorRowStore(Params const& params, SharedStorage const& shared_storage) + 
: params_ptr(¶ms) + , smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr = nullptr; + ElementCompute *smem_aux = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class GTensor, + class STensor, + class CoordGTensor, + class ThrResidue, + class TileCoordMN, + class ElementType, + class TiledCopy_ + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rSFD_, + GTensor&& tC_gSFD_, + STensor&& sAmaxs_, + CoordGTensor tC_cSFD_, + ThrResidue residue_tC_cSFD_, + Params const* params_ptr_, + TileCoordMN tile_coord_mn_, + ElementType norm_constant_, + ElementType norm_constant_scaled_down_, + int thread_idx_, + TiledCopy_ const&) + : tC_rSFD(cute::forward(tC_rSFD_)) + , tC_gSFD(cute::forward(tC_gSFD_)) + , sAmaxs(cute::forward(sAmaxs_)) + , tC_cSFD(tC_cSFD_) + , residue_tC_cSFD(residue_tC_cSFD_) + , params_ptr(params_ptr_) + , norm_constant(norm_constant_) + , norm_constant_scaled_down(norm_constant_scaled_down_) + , tile_coord_mn(tile_coord_mn_) + , thread_idx(thread_idx_) {} + + static_assert(is_same_v); + RTensor tC_rSFD; + GTensor tC_gSFD; + STensor sAmaxs; + CoordGTensor tC_cSFD; + ThrResidue residue_tC_cSFD; + Params const* params_ptr; + ElementCompute norm_constant; + ElementCompute norm_constant_scaled_down; + TileCoordMN tile_coord_mn; + int thread_idx; + static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; + static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); + static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; + static_assert(NumCollaboratingWarpGroups == 1 || 
NumCollaboratingWarpGroups == 2, + "SM120 epilogue currently only supports one or two warp groups collaborating."); + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, + int epi_m, + int epi_n, + Array const& frg_input) { + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + /* + Accumulator fragments are distributed across quads in different warps. + For SFVector = 16, we have: + + 8 elements 8 elements 8 elements 8 elements + <----------------><-----------------><-----------------><-----------------> + Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 + Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 + ... ... ... ... + Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 + Warp 0 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 Warp 4 Quad 0 + Warp 0 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 Warp 4 Quad 1 + ... ... ... ... + Warp 0 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 Warp 4 Quad 7 + + + + + + In this case, row-wise scale factors are cooperatively reduced across 4 + threads from 1 quad in 1 warp. Each quad computes its own, local absolute + maximum without communicating with other warps through shared memory. + + For SFVector = 32, we have: + 8 elements 8 elements 8 elements 8 elements + <----------------><-----------------><-----------------><-----------------> + Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 + Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 + ... ... ... ... + Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 + Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 + Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 + ... ... ... ... 
+ Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 + + + + + + For SFVector = 64, we have: + 8 elements 8 elements 8 elements 8 elements + <----------------><-----------------><-----------------><-----------------> + Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 + Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 + ... ... ... ... + Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 + Warp 0 Quad 0 Warp 2 Quad 0 Warp 4 Quad 0 Warp 6 Quad 0 + Warp 0 Quad 1 Warp 2 Quad 1 Warp 4 Quad 1 Warp 6 Quad 1 + ... ... ... ... + Warp 0 Quad 7 Warp 2 Quad 7 Warp 4 Quad 7 Warp 6 Quad 7 + + + + Thus, rowwise scale factors are cooperatively reduced across 8 threads + from two quads in two warps. Each quad first computes its own, local + absolute maximum and then shares this with the corresponding quad in the + other warp. In this case, a reduction through shared memory is needed. + + For a non-cooperative epilogue (in which each warpgroup computes a + separate tile), the pattern is the same as that above, except that warps 0 + and 2 are in the same row, and 1 and 3 are in the same row, and warps 4-7 + are not included. + */ + + // Accumulator fragments consist of two elements from two different rows of a 16x8 MMA output + static constexpr int ColsPerThreadAccFrag = 2; + static constexpr int RowsPerThreadAccFrag = 2; + static_assert(FragmentSize == + (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); + + static constexpr int NumThreadsPerQuad = 4; + static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); + // A quad from two or four warps participate in computing each scale factor. 
+ constexpr int WarpsPerSF = SFVecSize / 16; + static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); + + constexpr bool IsInterWarpReductionNeeded = (WarpsPerSF != 1); + + // Number of fragments for each thread that are needed for computing a scale factor + static constexpr int AccFragsPerSF = SFVecSize / (ColsPerThreadAccFrag * NumThreadsPerQuad * WarpsPerSF); + static_assert(size<2>(visit_results) % AccFragsPerSF == 0, + "Fragments along N mode must be a multiple of the number of accumulator fragments needed per SF"); + + auto warp_idx = thread_idx / NumThreadsPerWarp; + auto warpgroup_idx = thread_idx / NumThreadsPerWarpGroup; + auto quad_idx_in_warp = (thread_idx % NumThreadsPerWarp) / NumThreadsPerQuad; + auto thread_idx_in_quad = thread_idx % NumThreadsPerQuad; + + cutlass::maximum_absolute_value_reduction amax_op; + cutlass::multiplies mul; + + Tensor tC_rSFD_flt = filter_zeros(tC_rSFD); + + auto synchronize = [&] () { + cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + }; + + CUTLASS_PRAGMA_UNROLL + for (int sf_id = 0; sf_id < size(tC_rSFD_flt); ++sf_id) { + + auto coord = idx2crd(sf_id, tC_rSFD_flt.shape()); + auto row_in_acc = get<0,1,1>(coord); + auto row = crd2idx(get<1>(coord), get<1>(tC_rSFD_flt.shape())); + auto sf = crd2idx(get<2>(coord), get<2>(tC_rSFD_flt.shape())); + + // + // Compute amax for this scale factor + // + ElementCompute amax{0}; + + // Compute amax among vals owned by this thread for this vector + auto acc_frag_row = row_in_acc * RowsPerThreadAccFrag; + auto acc_frag_start_for_sf = sf * AccFragsPerSF; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccFragsPerSF; ++i) { + auto acc_frg = visit_results(0, row, acc_frag_start_for_sf + i); + amax = amax_op(amax, acc_frg[acc_frag_row]); + amax = amax_op(amax, acc_frg[acc_frag_row + 1]); + } + + // At this point, each thread has computed the amax of the 
values that it owns for this SF vector. + // We now need to compute the amax across threads. Because the TiledMMA uses an MmaThrLayout of <4,1,1>, + // we know that all fragments in this row will belong to threads in this warp. Furthermore, because + // SM120 narrow-precision MMAs have 16x8 output size with a quad owning two rows, we know that a quad + // will own all of the elements to be reduced via amax. Therefore, we can use warp shuffle intrinsics + // among threads in one quad to compute the amax. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < 3; ++i) { + auto amax_other = __shfl_xor_sync(0xffffffff, amax, i); + amax = amax_op(amax, amax_other); + } + + if constexpr (IsInterWarpReductionNeeded) { + // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its quad + // that should be used in computing the amax for this SF. Threads 0 in each quad of warps 0 and 2 + // (similarly, 1 and 3) now exchange amaxes to compute the final amax. + if (thread_idx_in_quad == 0) { + sAmaxs(quad_idx_in_warp, warp_idx) = amax; + } + synchronize(); + + // Get the amax broadcasted by the warp with which we share. + // Work on 4 warps per SFD generation + if constexpr (WarpsPerSF == 4) { + if constexpr (NumCollaboratingWarpGroups == 2) { + // This implementation assumes warp layout 2 x 4. + // For cooperative kernels (NumCollaboratingWarpGroups=2), + // warp 0 shares with 2 / 4 / 6, warp 1 shares with 3 / 5/ 7. 
+ auto amax_other2 = sAmaxs(quad_idx_in_warp, warp_idx ^ 2); + auto amax_other4 = sAmaxs(quad_idx_in_warp, warp_idx ^ 4); + auto amax_other6 = sAmaxs(quad_idx_in_warp, warp_idx ^ 6); + synchronize(); + amax = amax_op(amax, amax_other2); + amax = amax_op(amax, amax_other4); + amax = amax_op(amax, amax_other6); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported warp layout."); + } + } + // Work on 2 warps per SFD generation + else if constexpr(WarpsPerSF == 2) { + // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares + // with 4, 1 shares with 5, etc. For non-cooperative kernels + // (NumCollaboratingWarpGroups=1), 0 shares with 2, 1 shares with 3. + auto amax_other = sAmaxs( + quad_idx_in_warp, warp_idx ^ (1 << NumCollaboratingWarpGroups)); + synchronize(); + amax = amax_op(amax, amax_other); + } + } + + ElementCompute pvscale = mul(amax, norm_constant_scaled_down); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + tC_rSFD_flt(coord) = qpvscale; + + // + // Apply the scale factor to the output + // + ElementCompute qpvscale_rcp = [&]() { + if constexpr (cute::is_same_v) { + // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. + auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + } + else { + // UE4M3: Do the rcp in fp32 data type. 
+ auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); + } + }(); + + ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + + // Compute quantized output values + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccFragsPerSF; ++i) { + auto acc_frag = visit_results(0, row, acc_frag_start_for_sf + i); + visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row ] = mul(acc_frag[acc_frag_row], acc_scale); + visit_results(0, row, acc_frag_start_for_sf + i)[acc_frag_row + 1] = mul(acc_frag[acc_frag_row + 1], acc_scale); + } + } // sf + + // Since scale factors are computed cooperatively across two quads from two warps, we only need one thread from the + // set of 8 cooperating threads to write out the data. We do this with thread 0 in each quad of the first warp that collaborates. + bool write_sf = (thread_idx_in_quad == 0); + if constexpr (NumCollaboratingWarpGroups == 2) { + // For cooperative kernels (NumCollaboratingWarpGroups=2), 0 shares with 4, 1 shares with 5, etc. + // Thus, only the warps in the first warpgroup need to write out scale factors. + if constexpr (IsInterWarpReductionNeeded) { + write_sf &= warp_idx < NumWarpsPerWarpGroup; + } + } + else { + if constexpr (IsInterWarpReductionNeeded) { + // When non-cooperative kernels apply inter warp reduce, they are with + // SF output rule as below : + // 1. warp 0 shares with 2 and 1 shares with 3 within each warpgroup. + // 2. warps 0 and 1 of the first warpgroup and 4 and 5 of the second + // warpgroup need to write output sf. 
+ write_sf &= ((warp_idx < 2) || (warpgroup_idx == 1 && warp_idx < 6)); + } + } + + if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { + copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using Sm1xxBlockScaledOutputConfig = cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } + + auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + + static_assert(size<1>(EpilogueTile{}) && ((size<1>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), "Epilogue Tile N should be pow of 2"); + Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) + Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + + auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); + + // Fetch and compute these during initialization + Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), 
params_ptr->norm_constant_stride)); + ElementCompute norm_constant = mNormConst(_0{},_0{},l); + ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); + ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); + ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); + + Tensor sAmaxs = make_tensor( + make_smem_ptr(smem_aux), + make_layout(make_shape(Int{}, Int{})) + ); + + return ConsumerStoreCallbacks( + cute::move(tCrSFD), + cute::move(tCgSFD), + cute::move(sAmaxs), + args.tCcD, + args.residue_tCcD, + params_ptr, + tile_coord_mn, + norm_constant, + norm_constant_scaled_down, + args.thread_idx, + args.tiled_copy); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int SFVecSize, + class EpilogueTile, + class CtaTileShapeMNK, + int FragmentSize, + class ElementOutput, + class ElementCompute, + class ElementBlockScaleFactor, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct Sm120BlockScaleFactorColStore { + + static_assert(size<0>(EpilogueTile{}) % SFVecSize == 0, "EpilogueTileN should be divisible by SFVecSize"); + static_assert(size<0>(EpilogueTile{}) / SFVecSize == 1 or + size<0>(EpilogueTile{}) / SFVecSize == 2 or + size<0>(EpilogueTile{}) / SFVecSize == 4, + "Possible store in interleaved 4B aligned format"); + + static constexpr int NumWarpgroups = 2; + static constexpr int NumSyncWarps = NumWarpsPerWarpGroup * NumWarpgroups; + static constexpr int NumThreadsPerQuad = 4; + static constexpr int NumSyncElementsCrossWarp = NumSyncWarps * NumThreadsPerQuad; + struct SharedStorage { + array_aligned smem_aux; + }; + + using NormalConstStrideMNL = Stride<_0,_0,int64_t>; + + struct Arguments { + ElementBlockScaleFactor* ptr_scale_factor = {}; + // A matrix wide constant value to scale the output matrix + // Avoids generating small FP4 values. 
+ ElementCompute const* norm_constant_ptr = {}; + NormalConstStrideMNL norm_constant_stride = {}; + }; + using Params = Arguments; + + using UnderlyingElementBlockScaleFactor = cute::remove_pointer_t; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = (M % SFVecSize == 0); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: [EVT Sm120BlockScaleFactorColStore] N-dim should be divisible by SFVecSize.\n"); + } + return implementable; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm120BlockScaleFactorColStore() { } + + CUTLASS_HOST_DEVICE + Sm120BlockScaleFactorColStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) + , smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr = nullptr; + ElementCompute *smem_aux = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class GTensor, + class STensor, + class CoordGTensor, + class ThrResidue, + class TileCoordMN, + class ElementType, + class TiledCopy_ + > + struct ConsumerStoreCallbacks : 
EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rSFD_, + GTensor&& tC_gSFD_, + STensor&& sAmaxs_, + CoordGTensor tC_cSFD_, + ThrResidue residue_tC_cSFD_, + Params const* params_ptr_, + TileCoordMN tile_coord_mn_, + ElementType norm_constant_, + ElementType norm_constant_scaled_down_, + int thread_idx_, + TiledCopy_ const&) + : tC_rSFD(cute::forward(tC_rSFD_)) + , tC_gSFD(cute::forward(tC_gSFD_)) + , sAmaxs(cute::forward(sAmaxs_)) + , tC_cSFD(tC_cSFD_) + , residue_tC_cSFD(residue_tC_cSFD_) + , params_ptr(params_ptr_) + , norm_constant(norm_constant_) + , norm_constant_scaled_down(norm_constant_scaled_down_) + , tile_coord_mn(tile_coord_mn_) + , thread_idx(thread_idx_) {} + + static_assert(is_same_v); + RTensor tC_rSFD; + GTensor tC_gSFD; + STensor sAmaxs; + CoordGTensor tC_cSFD; + ThrResidue residue_tC_cSFD; + Params const* params_ptr; + ElementCompute norm_constant; + ElementCompute norm_constant_scaled_down; + TileCoordMN tile_coord_mn; + int thread_idx; + static constexpr int NumCollaboratingThreads = decltype(size(TiledCopy_{}))::value; + static_assert(NumCollaboratingThreads % NumThreadsPerWarpGroup == 0); + static constexpr int NumCollaboratingWarpGroups = NumCollaboratingThreads / NumThreadsPerWarpGroup; + static_assert(NumCollaboratingWarpGroups == 2, + "SM120 epilogue currently only supports two warp groups collaborating."); + static_assert(SFVecSize == 16 || SFVecSize == 32 || SFVecSize == 64, "SF vector size must be either 16, 32 or 64."); + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, + int epi_v, + int epi_m, + int epi_n, + Array const& frg_input) { + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(SmemTensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + /* + Accumulator fragments are distributed across threads/quads in different warps. For column major, the + reduction happens along M dimension. 
For SFVector = 32, we have: + + 8 elements 8 elements 8 elements 8 elements + + <----------------------><----------------------><----------------------><----------------------> + | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 + | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 + | ... ... ... ... + 1 | Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 + 6 | Warp 0 Quad 0 Warp 4 Quad 0 Warp 0 Quad 0 Warp 4 Quad 0 + | Warp 0 Quad 1 Warp 4 Quad 1 Warp 0 Quad 1 Warp 4 Quad 1 + | ... ... ... ... + + Warp 0 Quad 7 Warp 4 Quad 7 Warp 0 Quad 7 Warp 4 Quad 7 + | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 + | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 + 1 | ... ... ... ... + 6 | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 + | Warp 1 Quad 0 Warp 5 Quad 0 Warp 1 Quad 0 Warp 5 Quad 0 + | Warp 1 Quad 1 Warp 5 Quad 1 Warp 1 Quad 1 Warp 5 Quad 1 + | ... ... ... ... + | Warp 1 Quad 7 Warp 5 Quad 7 Warp 1 Quad 7 Warp 5 Quad 7 + + + + In this case, colum-wise scale factors are cooperatively reduced across 8 threads from 2 warps. + Each column first computes its own, local absolute maximum and then shares this with the + corresponding threads in the other warp. In this case, a reduction through shared memory is needed. + + For SFVector = 64, the reduction happens inside 4 warps: warp 0/1/2/3 and warp 4/5/6/7. 
+ */ + + // Accumulator fragments consist of two elements from two different columns of a 16x8 MMA output + static constexpr int RowsPerThreadAccFrag = 2; + static constexpr int ColsPerThreadAccFrag = 2; + static_assert(FragmentSize == (ColsPerThreadAccFrag * RowsPerThreadAccFrag)); + + static constexpr int NumThreadsPerCol = NumThreadsPerWarp / NumThreadsPerQuad; + constexpr int WarpsPerSF = SFVecSize / NumThreadsPerCol / ColsPerThreadAccFrag; + static_assert(WarpsPerSF == 1 || WarpsPerSF == 2 || WarpsPerSF == 4, "Only one, two or four warps are allowed in reduction."); + + auto warp_idx = thread_idx / NumThreadsPerWarp; + auto thread_idx_in_warp = thread_idx % NumThreadsPerWarp; + + cutlass::maximum_absolute_value_reduction amax_op; + cutlass::multiplies mul; + + auto synchronize = [&] () { + // When WarpsPerSF equals 1, data processing is inside warp, there is no needs to have the sync. + static constexpr bool NoSyncNeeded = (WarpsPerSF == 1); + if(NoSyncNeeded) + return; + cutlass::arch::NamedBarrier::sync(NumCollaboratingThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + }; + + CUTLASS_PRAGMA_UNROLL + for(int mma_in_epi = 0; mma_in_epi < size<1>(tC_rSFD)*size<2>(tC_rSFD); ++mma_in_epi) { + + CUTLASS_PRAGMA_UNROLL + for (int sf_id = 0; sf_id < ColsPerThreadAccFrag; ++sf_id) { + + // + // Compute amax for this scale factor + // + ElementCompute amax{0}; + + // Compute amax among vals owned by this thread for this vector + auto acc_frg = visit_results(mma_in_epi); + amax = amax_op(amax, acc_frg[sf_id]); + amax = amax_op(amax, acc_frg[sf_id + ColsPerThreadAccFrag]); + + // At this point, each thread has computed the amax of the values that it owns for this SF vector. + // We now need to compute the amax across threads. Because SM120 narrow-precision MMAs have 16x8 output + // size with a quad owning two rows, we know that 8 threads in one column will own all of the 16 elements + // to be reduced via amax. 
Therefore, we can use warp shuffle intrinsics among threads to compute the amax. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < NumThreadsPerCol; ++i) { + auto amax_other = __shfl_xor_sync(0xffffffff, amax, (i * NumThreadsPerQuad)); + amax = amax_op(amax, amax_other); + } + + // At this point, all threads in the quad have the amax for the elements of the accumulator owned by its + // threads that should be used in computing the amax for this SF. + if (thread_idx_in_warp < NumThreadsPerQuad && WarpsPerSF != 1) { + sAmaxs(thread_idx_in_warp, warp_idx) = amax; + } + + synchronize(); + + // Get the amax broadcasted by the warp with which we share. + // For cooperative kernels, when scale factor vector size is 32 (WarpsPerSF equals 2), + // warp 0 shares with 1, warp2 shares with 2, etc. + // When vector size is 64 (WarpsPerSF equals 4), warp 0 shares with 1/2/3, and 4 shares with 5/6/7. + // When vector size is 16, no needs to swap between warps. + if constexpr (2 == WarpsPerSF) { + auto amax_other = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); + amax = amax_op(amax, amax_other); + } + else if constexpr (4 == WarpsPerSF) { + auto amax_other1 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 1); + auto amax_other2 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 2); + auto amax_other3 = sAmaxs(thread_idx % NumThreadsPerQuad, warp_idx ^ 3); + amax = amax_op(amax, amax_other1); + amax_other2 = amax_op(amax_other2, amax_other3); + amax = amax_op(amax, amax_other2); + } + synchronize(); + + ElementCompute pvscale = mul(amax, norm_constant_scaled_down); + UnderlyingElementBlockScaleFactor qpvscale = NumericConverter{}(pvscale); + filter(tC_rSFD)(sf_id + mma_in_epi*ColsPerThreadAccFrag) = qpvscale; + + // + // Apply the scale factor to the output + // + ElementCompute qpvscale_rcp = [&]() { + if constexpr (cute::is_same_v) { + // UE8M0: Use integer subtraction to do the fast rcp in ue8m0 and then convert to float. 
+ auto e8m0_qpvscale_rcp = cutlass::reciprocal_approximate{}(qpvscale); + return cutlass::NumericConverter{}(e8m0_qpvscale_rcp); + } + else { + // UE4M3: Do the rcp in fp32 data type. + auto qpvscale_up = cutlass::NumericConverter{}(qpvscale); + return cutlass::reciprocal_approximate_ftz{}(qpvscale_up); + } + }(); + + ElementCompute acc_scale = mul(norm_constant, qpvscale_rcp); + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + + // Compute quantized output values + visit_results(mma_in_epi)[sf_id ] = mul(acc_frg[sf_id ], acc_scale); + visit_results(mma_in_epi)[sf_id + ColsPerThreadAccFrag] = mul(acc_frg[sf_id + ColsPerThreadAccFrag], acc_scale); + } // end for sf_id + } // end for mma_in_epi + + // Since scale factors are computed cooperatively across two or four warps, we only need one thread from the + // cooperating column threads group to write out the data. + bool write_sf = (thread_idx_in_warp < NumThreadsPerQuad); + if constexpr (2 == WarpsPerSF) { + // Output warp {0, 2, 4, 6}. + write_sf &= ((warp_idx & 0x1) == 0); + } + else if constexpr (4 == WarpsPerSF) { + // Output warp {0, 4}. + write_sf &= ((warp_idx & 0x3) == 0); + } + else if constexpr (1 == WarpsPerSF) { + // Output warp {0, 1, ..., 7}. Keep write_sf as is. + } + + if (write_sf && elem_less(tC_cSFD(_0{}, _0{}, _0{}, epi_m, epi_n), residue_tC_cSFD)) { + copy_aligned(tC_rSFD, tC_gSFD(_, _, _, _0{}, _0{}, get<0>(tile_coord_mn) + epi_m, get<1>(tile_coord_mn) + epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + UnderlyingElementBlockScaleFactor* ptr_scale_factor = nullptr; + // If Ptr-Array/Grouped GEMM with BlockScaleFactor per batch/group + if constexpr (!cute::is_same_v) { + ptr_scale_factor = params_ptr->ptr_scale_factor[l]; + l = 0; + } + else { + ptr_scale_factor = params_ptr->ptr_scale_factor; + } + + static_assert(size<0>(EpilogueTile{}) && ((size<0>(EpilogueTile{}) & (size<1>(EpilogueTile{}) - 1)) == 0), + "Epilogue Tile N should be pow of 2"); + + auto epi_tile_mn = shape<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)); + Tensor mSFD = make_tensor(make_gmem_ptr(ptr_scale_factor), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(args.problem_shape_mnkl)); + + Tensor gSFD = local_tile(mSFD, args.epi_tile, make_coord(_, _,l)); // (EPI_M,EPI_N, #EPI_Ms, #EPI_Ns) + Tensor tCgSFD = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,#EPI_Ms, #EPI_Ns) + gSFD, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrSFD = make_tensor_like(take<0,3>(cute::layout(tCgSFD))); // (CPY,CPY_M,CPY_N) + + auto tile_coord_mn = make_coord(m * size<0>(epi_tile_mn), n * size<1>(epi_tile_mn)); + + // Fetch and compute these during initialization + Tensor mNormConst= make_tensor(make_gmem_ptr(params_ptr->norm_constant_ptr), make_layout(make_shape(M, N, L), params_ptr->norm_constant_stride)); + ElementCompute norm_constant = mNormConst(_0{},_0{},l); + ElementCompute fp_max = ElementCompute(cutlass::platform::numeric_limits::max()); + ElementCompute scale_down_factor = cutlass::reciprocal_approximate_ftz{}(fp_max); + ElementCompute norm_constant_scaled_down = cutlass::multiplies{}(norm_constant, scale_down_factor); + + Tensor sAmaxs = make_tensor( + 
make_smem_ptr(smem_aux), + make_layout(make_shape(Int{}, Int{})) + ); + + return ConsumerStoreCallbacks( + cute::move(tCrSFD), + cute::move(tCgSFD), + cute::move(sAmaxs), + args.tCcD, + args.residue_tCcD, + params_ptr, + tile_coord_mn, + norm_constant, + norm_constant_scaled_down, + args.thread_idx, + args.tiled_copy); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..95e8208686ead6606040ee280023a7f5b879b07b --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,2792 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + >; + using Operation = fusion::ScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename 
Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C 
+ { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombinationPtrArray = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastPtrArray>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = 
Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + 
StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltActPtrArray = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltActPtrArray { + + using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = 
fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, 
Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerRowBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerColBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using 
StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltAct = + Sm90EVT, + Sm90LinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : 
Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-column bias) +template< + int StagesC, + class 
CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltAct = + Sm90EVT, + Sm90LinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = 
Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// Aux = alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerRowBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int 
AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = 
{}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per_col bias) +// Aux = alpha * acc + beta * C + per_col bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerColBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class 
CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + 
return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-row alpha * acc + per-row beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class 
ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBiasEltAct = + Sm90EVT, + Sm90PerRowLinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + using StrideAlpha = Stride; + using StrideBeta = Stride; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + StrideAlpha dAlpha = {bool(1), _0{}, 0}; + StrideBeta dBeta = {bool(1), _0{}, 0}; + + using StrideBias = 
Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-col alpha * acc + per-col beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // 
bias + > + >; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBiasEltAct = + Sm90EVT, + Sm90PerColLinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = 
ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(per-col alpha * acc + per-column bias) + per-col beta * C +template< + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColResAddPerColBiasEltAct = + Sm90EVT, // beta * C + activation(alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector 
broadcast + Sm90SrcFetch, // C + Sm90EVT, // activation(alpha * acc + bias) + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + > + >; + + template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColResAddPerColBiasEltAct< + CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColResAddPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha 
= {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + activation(alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // unary op : activation(alpha * acc + bias) + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + +// We only apply the scaling factor if output is fp8 +template +struct ScaleOutOp { template using Op = cutlass::first; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; + +template +using amax = cutlass::maximum_absolute_value_reduction; // propogate nans + +}; // end namespace detail + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * 
acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerRowBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, 
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, 
ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * 
C + per-row bias + Sm90ScaledLinCombPerColBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = 
Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class 
ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z 
= scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, 
ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; 
+ StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : 
reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class 
CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias + Sm90ScaledLinCombPerColBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + 
Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerColBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + 
fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + 
using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { 
// unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / 
sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + Sm90AuxLoad // aux + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, 
_0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltActDePerRowBias = + Sm90EVT, // Identity for final conversion + Sm90EVT, AlignmentBias>, + Sm90LinCombDeEltAct + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + 
FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias* dbias_ptr = nullptr; + StrideBias dDbias = {}; + + operator typename 
Impl::Arguments() const { + return + { // unary op : identity/convert + { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }, // end binary op + {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce + }, // end unary op + {} // unary args : identity/convert + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombTopKSoftmaxCol = + Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int TopK, + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombTopKSoftmaxCol, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombTopKSoftmaxCol { + + using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, 
RoundStyle>; + using Operation = fusion::LinCombTopKSoftmaxCol; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Grouped Wgrad Conv +template< + class GroupsPerTile, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombinationGroupedWgrad = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast>, // alpha + Sm90AccFetchGroupedWgrad // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class GroupsPerTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinearCombinationGroupedWgrad, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombinationGroupedWgrad::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombinationGroupedWgrad::type, 
ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombinationGroupedWgrad; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + //ElementScalar groups = ElementScalar(1); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template > +struct get_element_aux { + using type = void; +}; + +template +struct get_element_aux> { + using type = typename FusionOpOrCallbacks::ElementAux; +}; + +template +struct get_element_aux, cute::void_t<>> { + using type = typename get_element_aux::type; +}; + +template +struct get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename FusionCallbacks::Operation; + public: + using type = typename get_element_aux::type; +}; +} // namespace cutlass:epilogue::fusion::detail + +template +using get_element_aux_t = typename detail::get_element_aux::type; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ae63a7675c12dc4329374815da4d081a6bd885ee --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,842 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/detail/helper_macros.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// N-nary Elementwise Compute Operation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// The template argument provided for ComputeFn must be able to accept +// exactly one template parameter. In Standard C++, it's OK for +// ComputeFn to have other template parameters, as long as those have +// defaults. 
For example, the following struct Foo would work. +// +// template +// struct Foo { +// CUTLASS_HOST_DEVICE auto operator() (A a, B b); +// }; +// +// However, some compilers, such as Clang, require that the argument +// take _exactly_ one template parameter. This is nonstandard C++ +// behavior. One work-around for this case is to create a subclass +// with exactly one template parameter, and then use that subclass as +// the template argument. +// +// template +// struct FooHomogeneous : public Foo {}; +// +template< + template class ComputeFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class = void +> +struct Sm90Compute { +private: + using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; + + template + struct ComputeArguments { + using type = EmptyArguments; + }; + + // partial specialization for compute fns that define an Arguments member, e.g. activation hyperparameters + template + struct ComputeArguments> { + using type = typename Fn::Arguments; + }; + +public: + struct SharedStorage { }; + + using Arguments = typename ComputeArguments>::type; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const&, Arguments const&) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90Compute() + : params() {} + + CUTLASS_HOST_DEVICE + 
Sm90Compute(Params const& params, SharedStorage const& shared_storage) + : params(params) {} + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Params const& params) + : params(params) {} + + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + return transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) CUTLASS_LAMBDA_FUNC_INLINE { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + return convert_input(frg_input); + }, + [&] (auto&&... cvt_frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { + using ComputeOutput = ComputeFn>; + ComputeOutput compute_output{}; + + if constexpr (cute::is_same_v) { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + return convert_output(compute_output(cvt_frg_inputs...)); + } + else { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + return convert_output(compute_output(cvt_frg_inputs..., params)); + } + } + ); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(params); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Performance Optimized Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// beta * C + Z +template < + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class InputScaleOp, // beta + class ElementSource, // C + class InputAddOp // Z +> +struct Sm90TreeVisitor< + Sm90Compute().is_zero())>>, + InputScaleOp, + Sm90SrcFetch, + InputAddOp +> : Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || + added_op.is_producer_load_needed(); + } + else { + return is_C_load_needed() || added_op.is_producer_load_needed(); + } + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); + auto const& src_op = get<1>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + return (not scale_op.is_zero() && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); + } + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + 
ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) + : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } + + bool is_C_load_needed; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementZ = typename decltype(frg_added)::Element; + using ConvertZ = NumericArrayConverter; + using ConvertI = NumericArrayConverter; + ConvertZ convert_Z{}; + ConvertI convert_I{}; + + Array frg_I = convert_Z(frg_added); + + if constexpr (!is_void_v) { + Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementX = typename decltype(frg_scalar)::Element; + using ElementY = typename decltype(frg_source)::Element; + using ConvertX = NumericArrayConverter; + using ConvertY = NumericArrayConverter; + using ComputeI = multiply_add>; + ConvertX convert_X{}; + ConvertY convert_Y{}; + ComputeI compute_I{}; + + frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); + } + + return convert_I(frg_I); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + bool is_C_load_needed = this->is_C_load_needed(); + if (not is_C_load_needed) { + cute::clear(args.tCrC); + } + return ConsumerStoreCallbacks( + is_C_load_needed, std::move(callbacks_tuple)); + } +}; + +// ReLU with aux bit tensor dReLU/dZ +// Aux(i) = Z(i) >= 0 ? 
1 : 0 +namespace detail { +// Placeholder node so we can retain standard EVT structure +template +struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } +}; +} // namespace detail + +// Specialization on the generic compute+aux EVT +template < + // Compute node + template class Activation, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + // Aux node + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment, + bool EnableNullptr, + // Input node + class InputOp +> +struct Sm90TreeVisitor< + Sm90Compute, cutlass::epilogue::thread::ReLu> || + cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, + Sm90TreeVisitor< + Sm90AuxStore< + Stages, + EpilogueTile, + cutlass::uint1b_t, + RoundStyle, + StrideMNL, + SmemLayoutAtom, + CopyOpR2S, + Alignment, + EnableNullptr + >, + InputOp + > +> : Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + > +{ + using 
Impl = + Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) + : params(params_), Impl(params_, shared_storage) {} + + Params const& params; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + GTensor&& tC_gAux, + CTensor tC_cAux, + ThrResidue residue_tC_cAux, + Params const& params, + CallbacksImpl&& impl) + : tC_rAux(cute::forward(tC_rAux)), + tC_gAux(cute::forward(tC_gAux)), + tC_cAux(tC_cAux), + residue_tC_cAux(residue_tC_cAux), + params(params), + CallbacksImpl(cute::forward(impl)) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tC_cAux; + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); + + // Compute activation + aux + using ElementInput = typename decltype(frg_input)::Element; + using ConvertInput = NumericArrayConverter; + using ConvertAux = PackPredicates; + using ComputeOutput = Activation; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput relu{}; + ConvertAux convert_aux{}; + ConvertOutput convert_output{}; + + Array frg_compute = 
convert_input(frg_input); + bool frg_aux[FragmentSize]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + ElementCompute pre_relu = frg_compute[i]; + if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { + frg_compute[i] = relu(frg_compute[i], params_compute); + } + else { + frg_compute[i] = relu(frg_compute[i]); + } + if constexpr (cute::is_same_v) { + uint32_t aux; + asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else if constexpr (cute::is_same_v) { + uint32_t aux; + cutlass::half_t compute = frg_compute[i]; + asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else { + frg_aux[i] = frg_compute[i] == pre_relu; + } + } + + static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); + Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) + tC_rAux_frg(epi_v) = convert_aux(frg_aux); + + return convert_output(frg_compute); + } + + CUTLASS_DEVICE void + end() { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + callbacks_input.end(); + + // Nullptr is no-op + if constexpr (EnableNullptr) { + if (params_aux.ptr_aux == nullptr) { + return; + } + } + + // Compute vectorization + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + // Copy vectorizes into byte-aligned stores + if constexpr (V > 1 && V % 8 == 0) { + using VecType = uint_bit_t; + Tensor 
tC_rAux_vec = recast(tC_rAux); + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); + } + // sub-byte vectorization, must serialize threads + else { + // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) + int lane_idx = canonical_lane_idx(); + Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < NumThreadsPerWarp; ++i) { + if (lane_idx == i) { + copy_if(tC_pAux, tC_rAux, tC_gAux); + } + __syncwarp(); + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // Unpack params + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(params_aux.ptr_aux); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params, cute::move(callbacks_impl)); + } +}; + +// Aux load for uint1b_t +template < + int Stages, 
+ class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + Stages, + EpilogueTile, + cutlass::uint1b_t, + StrideMNL, + SmemLayoutAtom, + CopyOpS2R, + Alignment, + EnableNullptr +> { + static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); + + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t const* ptr_aux = nullptr; + cutlass::uint1b_t null_default = cutlass::uint1b_t(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const&) + : params(params) { } + + Params const params; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ThrResidue residue_tC_cAux_, Params const& params_) + : tC_rAux(cute::forward(tC_rAux_)), + tC_gAux(cute::forward(tC_gAux_)), + 
tC_cAux(tC_cAux_), + residue_tC_cAux(residue_tC_cAux_), + params(params_) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tC_cAux; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + if constexpr (V > 1) { + using VecType = uint_bit_t; + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); + } + else { + Tensor tC_pAux = cute::lazy::transform(tC_cAux, [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux, tC_gAux, tC_rAux); + } + } + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + Tensor tC_pAux = cute::lazy::transform(tC_cAux(_,_,_,epi_m,epi_n), [&](auto const& c){ return elem_less(c, residue_tC_cAux); }); + copy_if(tC_pAux, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + } + } + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + using ElementRegister = typename remove_cvref_t::value_type; + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + return recast>(coalesce(tC_rAux))(epi_v); + } + else { + return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src 
or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(params.ptr_aux); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + + // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences + constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; + Tensor tC_rAux = [&] () CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (V % 8 != 0) { + return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + } else { + return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + } + }(); + + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + fill(tC_rAux, params.null_default); + } + } + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params); + } +}; + +// dReLU specialization +template< + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle +> +struct Sm90Compute< + cutlass::epilogue::thread::dReLU, + ElementOutput, + ElementCompute, + RoundStyle +> : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input, + Array const& frg_aux) { + using ConvertInput = NumericArrayConverter; + using ComputeOutput = cutlass::epilogue::thread::dReLU>; + using ConvertOutput = NumericArrayConverter; 
+ ConvertInput convert_input{}; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..535d8b082d44ff796fe2efc4e1531b4a3dc2674c --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -0,0 +1,1492 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/helper_macros.hpp" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Fetch Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns accumulator +struct Sm90AccFetch : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return frg_acc; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks{}; + } +}; + +// Split tree visitor fetches intermediate results from temporary accumulators +using Sm90SplitTreeFetch = Sm90AccFetch; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns C +template +struct Sm90SrcFetch : Sm90VisitorImpl<> { + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return is_C_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return not is_void_v; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_void_v; + } + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(SrcTensor const& tCrC) + : tCrC(tCrC) {} + + SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tCrC)(epi_v); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // register type may differ from logical type so we can't assert matching types here + return ConsumerStoreCallbacks(args.tCrC); + } +}; + +// returns accumulator in Grouped Conv Wgrad +template +struct Sm90AccFetchGroupedWgrad : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + using GroupsPerTile = GroupsPerTile_; + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(int32_t thread_idx) + : thread_idx(thread_idx) { } + + int32_t thread_idx; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + + Array frg_acc_rst; + int warp_id = thread_idx / 32; + + // In Grouped Wgrad, only diagonal block data is valid and the others is wrong and useless. + // One block size is C/G x C/G. Note that C/G = Tile_N / GroupsPerTile. + // Copy diagonal block ACC into the first block Col which is the output tensor size Tile_M * C/G. + // Then we can store the valid output tensor tile directly. + if constexpr ( cute::is_same_v ) { + frg_acc_rst = frg_acc; + } + else if constexpr ( cute::is_same_v ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 16; i++) { + frg_acc_rst[i] = frg_acc[i + warp_id / 2 * 16]; + } + } + else if constexpr ( cute::is_same_v ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 8; i++) { + frg_acc_rst[i] = frg_acc[i + warp_id * 8]; + } + } + else if constexpr ( cute::is_same_v ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; i++) { + frg_acc_rst[i] = frg_acc[i + warp_id * 8 + i / 2 * 4]; + } + } + + return frg_acc_rst; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(args.thread_idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90AuxLoad { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using CopyOpG2S = + SM90_TMA_LOAD + ; + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), + take<0,2>(SmemLayoutTma{}))); 
+ TMA_Aux tma_load_aux; + Element null_default = Element(0); + bool use_default = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto M_AUX = + size(M) + ; + Tensor tensor_aux = make_tensor(make_gmem_ptr(args.ptr_aux), make_layout(make_shape(M_AUX,N,L), append<3>(args.dAux, _0{}))); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(CopyOpG2S{}, tensor_aux, take<0,2>(SmemLayoutTma{})); + + bool use_default = false; + if constexpr (EnableNullptr) { + use_default = args.ptr_aux == nullptr; + } + + return Params{tma_load_aux, args.null_default, use_default}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params_ptr->use_default && params_ptr->null_default == Element(0)); + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE 
+ ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) + : bGS_gAux(cute::forward(bGS_gAux)), + bGS_sAux(cute::forward(bGS_sAux)), + params_ptr(params_ptr) {} + + GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + return; + } + } + + if (issue_tma_load) { + // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size + constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bits_v / 8; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA load + constexpr uint16_t mcast_mask = 0; + int load_pipe_index = load_iteration % Stages; + copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), + bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + + ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); + Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); 
// (TMA,TMA_M,TMA_N,PIPE) + + return ProducerLoadCallbacks( + cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) + : tC_rAux(cute::forward(tC_rAux)), + tiled_s2r(tiled_s2r), + tSR_sAux(cute::forward(tSR_sAux)), + params_ptr(params_ptr) { } + + TiledS2R tiled_s2r; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + + using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); + Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) + + int load_pipe_index = load_iteration % Stages; + copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + + return tC_rAux_frg(epi_v); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor tC_gAux = sm90_partition_for_epilogue(mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + auto tiled_s2r = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); + } +}; + +template < + class Element, + class EpilogueTile, // Unused + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpS2R, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + 0, EpilogueTile, Element, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, 
Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorG2R, + class RTensor, + class CTensorG2R, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, + RTensor&& tC_rAux, + CTensorG2R&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorG2R tC_gAux; + RTensor tC_rAux; + CTensorG2R tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); + Tensor 
tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); + + copy_if(tC_pAux_vec, tC_gAux_vec, tC_rAux_vec); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tC_rAux)(epi_v); + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Broadcast Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL_ = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcast { + using StrideMNL = StrideMNL_; + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + 
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + if (get<2>(params_ptr->dScalar[0]) == 0) { + // Only 1 batch + return scalar == Element(0); + } + else { + // multiple batch + if (valid_scalar == false) { + // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. + return params_ptr->scalar_ptrs[0] == nullptr; + } + else { + // Check whether each batch is ZERO or not. 
+ return scalar == Element(0); + } + } + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + bool valid_scalar = false; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (size<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + valid_scalar = true; + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } + + template + CUTLASS_DEVICE void + update_scalar(cute::tuple) { + // Only support multiple L-modes with fully-broadcast scalar + scalar = params_ptr->scalars[0]; + valid_scalar = true; + } +}; + +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL_ = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcastPtrArray { + using StrideMNL = StrideMNL_; + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + Element const* const* 
scalar_ptr_arrays[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + // producer load is needed if Element is not void + return !cute::is_void_v; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + return scalar == Element(0); + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Always refresh scalar with the current group index so per-group + // alpha/beta values (provided through pointer arrays) are loaded + // correctly even when the L-stride is zero. 
+ auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptr_arrays[0] != nullptr) { + // Pointer-array variant: each entry already points to the scalar of a group. + scalar = *(params_ptr->scalar_ptr_arrays[0][l_coord]); + } + else if (params_ptr->scalar_ptrs[0] != nullptr) { + // Strided pointer variant. + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // Literal fallback. 
+ scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + + if (params_ptr->scalar_ptr_arrays[i] != nullptr) { + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][l_coord])); + } + else if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +[[deprecated("row broadcast only uses 0 stages")]] constexpr int +compute_row_broadcast_stages() { + return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; +} + +} + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class ElementInput_, + class ElementCompute = cute::remove_pointer_t, + class StrideMNL_ = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v>, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90RowBroadcast { + using StrideMNL = StrideMNL_; + // Get base element input type. + using ElementInput = cute::remove_pointer_t; + // Check if input is an array of pointers. 
+ static constexpr bool IsArrayOfPointers = is_same_v; + using PtrRowType = cute::conditional_t; + + static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + struct Arguments { + PtrRowType ptr_row = nullptr; + ElementInput null_default = ElementInput(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), is_zero_(false), + smem(const_cast(shared_storage.smem.data())) { + auto const& [stride_M, stride_N, stride_L] = params.dRow; + // Nullptr default + if (EnableNullptr && params.ptr_row == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { + if constexpr (!IsArrayOfPointers) { + is_zero_ = 
params.ptr_row[0] == ElementInput(0); + } + } + } + + Params params; + bool is_zero_ = false; + ElementInput *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_zero_; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + Residue residue_cRow_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , residue_cRow(residue_cRow_) + , params(params_) { + } + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + Residue residue_cRow; // (m, n) + Params const& params; + + CUTLASS_DEVICE void + begin() { + bool is_nullptr = EnableNullptr && params.ptr_row == nullptr; + + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (not is_nullptr && elem_less(tGS_cRow_flt(i), residue_cRow)) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); // issue async gmem to smem load + } + else { + tGS_sRow_flt(i) = params.null_default; // fill OOB values so smem to RF load can issue without predication + 
} + } + } + + CUTLASS_DEVICE bool + begin_sync_needed() const { + return true; // Ensure visibility of async gmem to smem loads + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); + copy_aligned(tSR_sRow_flt, tSR_rRow_flt); + + constexpr int FrgSize = size(tSR_rRow_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); + Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); + ConvertInput convert_input{}; + + tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + auto layout_N = [&] () CUTLASS_LAMBDA_FUNC_INLINE { + auto shape_N = get<1>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_N = repeat_like(shape_N, int(0)); + if (get<1>(params.dRow) == bool(1)) { + stride_N = transform_leaf(compact_major(shape_N), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_N, stride_N); + } + else { + return make_layout(shape_N); + } + }(); + + auto layout_M = make_layout(M, repeat_like(M, _0{})); + auto layout_L = make_layout(L, get<2>(params.dRow)); + ElementInput const* ptr_row = nullptr; + if constexpr(IsArrayOfPointers) { + if (!(EnableNullptr && params.ptr_row == nullptr)) { + ptr_row = params.ptr_row[l]; + } + } else { + ptr_row = params.ptr_row; + } + Tensor mRow = make_tensor(make_gmem_ptr(ptr_row), make_layout(layout_M,layout_N,layout_L)); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + Tensor tGS_cRow = thr_g2s.partition_S(args.cD); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + 
tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.residue_cD, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class ElementInput_, + class ElementCompute = cute::remove_pointer_t, + class StrideMNL_ = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v>, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90ColBroadcast { + using StrideMNL = StrideMNL_; + // Get base element input type. + using ElementInput = cute::remove_pointer_t; + // Check if input is an array of pointers. + static constexpr bool IsArrayOfPointers = is_same_v; + using PtrColType = cute::conditional_t; + + static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + PtrColType ptr_col = nullptr; + ElementInput null_default = ElementInput(0); + StrideMNL dCol = {}; + }; + + struct Params { + PtrColType ptr_col = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dCol = {}; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, 
Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_zero_; + } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), is_zero_(false) { + auto const& [stride_M, stride_N, stride_L] = params.dCol; + // Nullptr default + if (EnableNullptr && params.ptr_col == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { + if constexpr (!IsArrayOfPointers) { + is_zero_ = params.ptr_col[0] == ElementInput(0); + } + } + } + + Params params; + bool is_zero_; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) + : tCgCol(tCgCol_), + tCrCol(tCrCol_), + tCcCol(tCcCol_), + residue_tCcCol(residue_tCcCol_), + params(params_) { + if (EnableNullptr && params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + } + } + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcCol; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (EnableNullptr && 
params.ptr_col == nullptr) { + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tCgCol_flt = filter_zeros(tCgCol); + Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); + Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); + + constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + if constexpr (V > 1) { + using VecType = uint_bit_t>; + Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); + Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); + Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); + Tensor tCpCol_vec = cute::lazy::transform(tCcCol_vec, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); + copy_if(tCpCol_vec, tCgCol_vec, tCrCol_vec); + } + else { + Tensor tCpCol_flt = cute::lazy::transform(tCcCol_flt, [&](auto const& c){ return elem_less(c, residue_tCcCol); }); + copy_if(tCpCol_flt, tCgCol_flt, tCrCol_flt); + } + + constexpr int FrgSize = size(tCrCol_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); + Tensor tCrCol_compute_frg = recast(filter(tCrCol)); + ConvertInput convert_input{}; + + tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + auto layout_M = [&] () CUTLASS_LAMBDA_FUNC_INLINE { + auto shape_M = get<0>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_M = repeat_like(shape_M, int(0)); + if (get<0>(params.dCol) == bool(1)) { + stride_M = transform_leaf(compact_major(shape_M), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_M, stride_M); + } + else { + return make_layout(shape_M); + } + }(); + + auto layout_N = make_layout(N, repeat_like(N, _0{})); + auto layout_L = make_layout(L, get<2>(params.dCol)); + ElementInput const* ptr_col = nullptr; + if constexpr(IsArrayOfPointers) { + if (!(EnableNullptr && params.ptr_col == nullptr)) { + ptr_col = params.ptr_col[l]; + } + } else { + ptr_col = params.ptr_col; + } + Tensor mCol = make_tensor(make_gmem_ptr(ptr_col), make_layout(layout_M,layout_N,layout_L)); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + Tensor mCol_static = make_tensor(make_gmem_ptr(ptr_col), make_layout(make_layout(M),layout_N,layout_L)); + Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix broadcast +// Only need to redefine this if we can multicast across cluster L +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + 
class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +using Sm90MatrixBroadcast + = Sm90AuxLoad; + +namespace detail { + +template +struct IsScalarBroadcast { + static constexpr bool value = false; +}; + +template +struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { + static constexpr bool value = true; +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..06ad8082e57cedf4d16aecdad8a995e838e1c93e --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,1722 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + FloatRoundStyle RoundStyle, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90AuxStore { + using ElementAux = Element; + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + struct SharedStorage { + 
alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), + SmemLayoutTma{})); + TMA_Aux tma_store_aux; + bool is_nullptr = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + bool is_nullptr = false; + if constexpr (EnableNullptr) { + is_nullptr = args.ptr_aux == nullptr; + } + + typename Params::TMA_Aux tma_store_aux; + if (not is_nullptr) { + Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); + tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); + } + + return {tma_store_aux, is_nullptr}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + 
CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class TiledR2S, + class STensorR2S, + class STensorS2G, + class GTensorS2G + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + TiledR2S tiled_r2s, + STensorR2S&& tRS_sAux, + STensorS2G&& bSG_sAux, + GTensorS2G&& bSG_gAux, + Params const* params_ptr) + : tiled_r2s(tiled_r2s), + tC_rAux(cute::forward(tC_rAux)), + tRS_sAux(cute::forward(tRS_sAux)), + bSG_sAux(cute::forward(bSG_sAux)), + bSG_gAux(cute::forward(bSG_gAux)), + params_ptr(params_ptr) {} + + TiledR2S tiled_r2s; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) + STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) + GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); + Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) + + if (issue_smem_store) { + int store_pipe_index = store_iteration % Stages; + copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); + } + } + + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr 
(EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + if (issue_tma_store) { + // Issue the TMA store + int store_pipe_index = store_iteration % Stages; + copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), bSG_gAux(_,_,_,epi_m,epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) + + ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); + Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), + tiled_r2s, + cute::move(tRS_sAux), + cute::move(bSG_sAux), + cute::move(bSG_gAux), + params_ptr); + } +}; + +template < + class Element, + class EpilogueTile, // Unused + FloatRoundStyle RoundStyle, + 
class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpR2S, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxStore< + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorR2G, + class RTensor, + class CTensorR2G, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensorR2G&& tC_gAux, + RTensor&& tC_rAux, + CTensorR2G&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : 
tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorR2G tC_gAux; + RTensor tC_rAux; + CTensorR2G tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux(_,_,_,epi_m,epi_n)), MCL.compose(Int{}))); + Tensor tC_pAux_vec = cute::lazy::transform(tC_cAux_vec, [&](auto const& c){ return elem_less(c, problem_shape_mnl); }); + + copy_if(tC_pAux_vec, tC_rAux_vec, tC_gAux_vec); + } + }; + + template < + bool ReferenceSrc, + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Reduction Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar reduction +template < + template class RegReduceFn, + template class GmemReduceFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_0,_0>, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90ScalarReduction { +private: + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); + +public: + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_scalar = nullptr; + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dScalar = {}; + }; + + 
using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + #if !defined(CUTLASS_SKIP_REDUCTION_INIT) + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); + if (args.ptr_scalar != nullptr) { + return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); + } + } + #endif + + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + int l_coord, + CTensor tCcScalar, + ThrResidue residue_tCcScalar, + Params const& params) + : scalar(params.reduction_identity), + l_coord(l_coord), + tCcScalar(tCcScalar), + residue_tCcScalar(residue_tCcScalar), + params(params) {} + + ElementCompute scalar; + int l_coord; + 
CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcScalar; + Params params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_tCcScalar)) { + scalar = reduce_input(scalar, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return; + } + } + + using ConvertI = NumericConverter; + using ReduceInput = GmemReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); + reduce_input(ptr_scalar, convert_I(scalar)); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks( + get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_tCcD, params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Row vector reduction +template < + template class RegReduceFn, + template class ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_row is assumed to point to a compact n-major (ceil_div(M,CTA_M), round_nearest(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true, + // Indicate the parameter order when calling RegReduceFn + // Seq length equals the number of RegReduceFn parameters + // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` + class RegReduceSeq = cute::seq<0, 1> +> +struct Sm90RowReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_row = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute 
reduction_identity = ElementCompute(0); + StrideMNL dRow = {}; + }; + + struct Params { + void* ptr_row = nullptr; + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dRow = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + else { + reduction_buffer = reinterpret_cast(args.ptr_row); + } + + return { + args.ptr_row, + args.reduction_identity, + args.dRow, + reduction_buffer, + tile_counters + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + // Align and increment by size of 
tile counters + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); + if (args.ptr_row != nullptr) { + return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90RowReduction() { } + + CUTLASS_HOST_DEVICE + Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return 
EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return cute::get<0>(cute::make_tuple(frg_inputs...)); + } + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); + + if constexpr (VisitCheckOOB) { + using ReduceInput = RegReduceFn; + ReduceInput reduce_input{}; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + return ElementCompute(frg_input[i]); + }, + [&] (auto&&... 
cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + } + else { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + using ReduceInput = RegReduceFn>; + ReduceInput reduce_input{}; + Tensor tCrRow_mn_frg = recast>(tCrRow_mn); + + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); + tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + using RegFragArr = Array, RegFragArraySize>; + ConvertInput convert_input{}; + return convert_input(reinterpret_cast(frg_input)[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + return cute::get<0>(cute::make_tuple(frg_inputs...)); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cRow(_0{},_0{}), residue_cRow)) { + return; + } + + int lane_m = get<0>(lane_mn); + 
[[maybe_unused]] bool is_reduced_lane = lane_m == 0; + + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + Tensor tCrRow_frg = recast(filter(tCrRow)); + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + + auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); + constexpr bool SwapShuffle = FrgSizePerLaneM > 0; + + // + // Swap Shuffle + // + // The normal way to reduction among threads: + // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. + // After each step of reduction, a half of threads won't work in the following steps. + // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). + // + // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, + // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. + // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. + // We can recursively do this until the problem size is 1. 
+ // + if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors + Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) + CUTLASS_PRAGMA_UNROLL + for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < m; ++r) { + auto frg_A = tCrRow_frg_(_,r); + auto frg_B = tCrRow_frg_(_,r + m); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(frg_A); ++v) { + // Step1: swap + if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second + cutlass::swap(frg_A(v), frg_B(v)); + } + + // Step2: shuffle + uint64_t frg_shfl = reinterpret_cast(frg_A(v)); + // each half of threads get a half of data from the other half of threads + frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); + + // Step3: reduction + frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); + } + } + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); + tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + } + + // + // 2. 
Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrRow_flt = filter_zeros(tCrRow); + Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + + Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgRow_flt = filter_zeros(tCgRow); + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if constexpr (SwapShuffle) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FltFrgSizePerLaneM; ++i) { + int idx = lane_m * FltFrgSizePerLaneM + i; + // Only care about OOB for N mode + if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { + reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); + } + } + } + else { + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + if (elem_less(tCcRow_flt(i), residue_tCcRow)) { + reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + } + } + } + } + sync_fn(); + } + + // + // 2. 
One warp in M, skip threadblock smem reduction + // + else if constexpr (decltype(size<0>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCgBuf_flt = recast(filter(tCgBuf)); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + copy_aligned(tCrRow, recast(tCgBuf)); + } + } + sync_fn(); + } + + // + // 2. Multiple warps in M, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCsBuf_flt = filter(tCsBuf); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + copy_aligned(tCrRow, tCsBuf); + } + } + sync_fn(); + + constexpr int 
SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; + + constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; + using FragmentSmemArray = Array; + + // Do the threadblock smem reduction + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { + FragmentSmemArray frg_smem; + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { + int FragsCurrRows = reduction_rows * FragsPerRow; + frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); + } + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { + frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); + } + } + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); + } + sync_fn(); + } + + // + // 3. 
Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[n], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_ml) * size<3>(gBuf_ml) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gRow_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + Tensor tRgBuf_ml = gBuf_ml(_0{},n,_,_); + ElementCompute output = tRgBuf_ml(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { + output = reduce_output(output, tRgBuf_ml(ml)); + } + if (elem_less(cRow(_0{},n), residue_cRow)) { + gRow_l(_0{},n,_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + bool do_store = elem_less(cRow(_0{},n), residue_cRow); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_ml); ++l) { + 
          Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l);
          ElementCompute output = tRgBuf_m(_0{});
          // Serially fold all CTA-tile partials along M for this (n, l).
          CUTLASS_PRAGMA_NO_UNROLL
          for (int m = 1; m < size(tRgBuf_m); ++m) {
            output = reduce_output(output, tRgBuf_m(m));
          }
          if (do_store) {
            gRow_l(_0{},n,l) = convert_output(output);
          }
        }
      }
    }

    }
  }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  // NOTE(review): several template-argument lists in this method were lost in
  // extraction (e.g. ConsumerStoreArgs, Int, Layout/Shape, sm90_partition_for_epilogue,
  // bool_constant) -- restore them from the upstream source before building.
  get_consumer_store_callbacks(ConsumerStoreArgs const& args) {
    // Map each epilogue tile coordinate back to the tiled-copy (thread, value) index.
    Layout ref_layout_MN = [&] () {
      auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{});
      if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); }
      else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); }
    }(); // tile_mn -> tv_idx

    // Get the MN layout + coord of lanes to determine shuffle reduction iterations
    using _W = Int; // NOTE(review): argument list stripped; presumably the warp count -- confirm upstream
    Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx
    Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx
    Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx
    Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn
    int lane_idx = canonical_lane_idx();
    auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN));

    // Get the MN layout + coord of warps to determine smem reduction iterations
    Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx
    Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx
    Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx
    Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn

    int warp_idx = args.thread_idx / NumThreadsPerWarp;
    auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN));

    // Partition output gmem and register tensors
    auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk;
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); // (M,N,L)
    Tensor gRow_l = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L)
    Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
                      gRow_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Start all register accumulators at the reduction identity.
    fill(tCrRow, params.reduction_identity);

    // Partition gmem+smem reduction buffer tensors
    Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_0{}, _1{}));
    auto block_shape = ceil_div(make_shape(M,N,L), shape(gBuf_layout)); // (M_CNT, N_CNT, L_CNT)

    // Let the M_CNT (the num of partial reduction results) become the outer mode
    Layout block_layout = make_layout(block_shape, make_stride(get<1>(block_shape), _1{}, get<0>(block_shape) * get<1>(block_shape)));
    Layout mBuf_layout = blocked_product(gBuf_layout, block_layout);
    Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L)
    Tensor gBuf_ml = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(_,n,_)); // (CTA_M,CTA_N,REST_M,L)
    Layout sBuf_layout = blocked_product(gBuf_layout, // (CTA_M,CTA_N,WARPS_M)
                           make_layout(make_shape(_1{},_1{},size<0>(warp_layout_MN))));

    // Bundle everything the device callbacks need into one tuple.
    auto args_tuple = make_tuple(
      bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout,
      lane_layout_MN, lane_mn, warp_layout_MN, warp_mn,
      args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx);
    return ConsumerStoreCallbacks(cute::move(args_tuple), params);
  }
};

+///////////////////////////////////////////////////////////////////////////////////////////////// + +// Col vector reduction +template < + template class RegReduceFn, + template class ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true +> +struct Sm90ColReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dCol = {}; + }; + + struct Params { + void* ptr_col = nullptr; + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dCol = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; 
+ if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + else { + reduction_buffer = reinterpret_cast(args.ptr_col); + } + + return { + args.ptr_col, + args.reduction_identity, + args.dCol, + reduction_buffer, + tile_counters + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout 
mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); + if (args.ptr_col != nullptr) { + return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return frg_input; + } + } + + 
auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (!VisitCheckOOB || elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_tCcCol)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; + + // Runtime nullptr is noop + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + + // + // 1. 
Warp shuffle reduction + // + using FragmentShuffle = Array; + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + Tensor tCrCol_frg = recast(filter(tCrCol)); + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); + tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + bool is_reduced_lane = get<1>(lane_mn) == 0; + + // + // 2. Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + + Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgCol_flt = filter_zeros(tCgCol); + + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrCol_flt); ++i) { + if (elem_less(tCcCol_flt(i), residue_tCcCol)) { + reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); + } + } + } + sync_fn(); + } + + // + // 2. One warp in N, skip threadblock smem reduction + // + else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + copy_aligned(tCrCol, recast(tCgBuf)); + } + sync_fn(); + } + + // + // 2. 
Multiple warps in N, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + copy_aligned(tCrCol, tCsBuf); + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; + + // Do the threadblock smem reduction + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { + int FragsPerReduction = reduction_cols * FragsPerCol; + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); + sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // Do final smem reduction and dump to gmem workspace + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // + // 3. 
Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[m], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gCol_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); + ElementCompute output = tRgBuf_nl(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { + output = reduce_output(output, tRgBuf_nl(nl)); + } + if (elem_less(cCol(m,_0{}), residue_cCol)) { + gCol_l(m,_0{},_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + bool do_store = elem_less(cCol(m,_0{}), residue_cCol); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_nl); ++l) { + 
Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); + ElementCompute output = tRgBuf_n(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 1; n < size(tRgBuf_n); ++n) { + output = reduce_output(output, tRgBuf_n(n)); + } + if (do_store) { + gCol_l(m,_0{},l) = convert_output(output); + } + } + } + } + + } + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // 
Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) + Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + fill(tCrCol, params.reduction_identity); + + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); + Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) + Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) + + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix reduction +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class CopyOpR2S, + class SmemLayoutAtom, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct 
Sm90MatrixReduction; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..93720f8d3d71f3f4759463b5d40e604313b7e3a4 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -0,0 +1,1149 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree operation base implementation to enable composable fusions + for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/detail/helper_macros.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using cute::tuple; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partitioning Helpers +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class CtaTileMN, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + 
CtaTileMN cT, // (CTA_M,CTA_N,...) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); + Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) + if constexpr (ReferenceSrc) { + return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } + else { + return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } +} + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class Engine, class LayoutMNL, + class TileShapeMNK, + class TileCoordMNKL, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + Tensor mT, // (M,N,L) + TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) + TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + auto [m, n, k, l] = tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor tCcT = + sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return tCcT; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Visitor Implementation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Producer load callbacks, called by the epilogue load warp. +// Operations usually only define this if TMA load is needed. 
Most operations will reuse this empy implementation +// Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but +// are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead +// If this is non-empty, is_producer_load_needed must be true. +// +template +struct ProducerLoadCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + + +// +// Consumer store callbacks, called by the epilogue store warps. +// All operations must redefine this, with optional inheritance from this empty implementation. +// +template +struct ConsumerStoreCallbacksImpl { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. 
+ CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin(); + } + ); + } + + // Is a thread sync needed after begin(). Allows chaining async copies across multiple nodes + CUTLASS_DEVICE bool + begin_sync_needed() const { + return cute::apply(callbacks_tuple, + [] (auto const&... callbacks) { + return (false || ... || callbacks.begin_sync_needed()); + } + ); + } + + // Start of subtile store iteration + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.begin_loop(epi_m, epi_n); + } + ); + } + + // Before visit callback. Smem broadcasts usually performed here. + // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After visit call. Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each nodes reponsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. 
results are synchronized and no longer in the reduction buffer + // + // visit_results is a rmem tensor that contains the results of visit() for an entire + // on the current epilogue subtile + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); + } + ); + } + + // After reduce call, before smem async fence. Smem stores usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // After smem async fence, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores + // other gmem stores can be placed in the reduce or postreduce entry points + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // End of subtile store iteration + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end_loop(epi_m, epi_n); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. 
+ CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + callbacks.end(); + } + ); + } +}; + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class EpilogueTile +> +struct ProducerLoadArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + TiledMma tiled_mma; + EpilogueTile epi_tile; + int thread_idx; + + CUTLASS_DEVICE + ProducerLoadArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + EpilogueTile epi_tile, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + tiled_mma(tiled_mma), + epi_tile(epi_tile), + thread_idx(thread_idx) {} +}; + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class EpilogueTile, + class TiledCopy, + class CoordTensor, + class Residue, + class ThrCoordTensor, + class ThrResidue, + class ThrSrcTensor +> +struct ConsumerStoreArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + TiledMma tiled_mma; + EpilogueTile epi_tile; + TiledCopy tiled_copy; + CoordTensor cD; + Residue residue_cD; + ThrCoordTensor tCcD; + ThrResidue residue_tCcD; + ThrSrcTensor & tCrC; + int thread_idx; + + CUTLASS_DEVICE + ConsumerStoreArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + CoordTensor cD, + Residue residue_cD, + ThrCoordTensor tCcD, + ThrResidue residue_tCcD, + ThrSrcTensor & tCrC, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + tiled_mma(tiled_mma), + epi_tile(epi_tile), + tiled_copy(tiled_copy), + cD(cD), + 
residue_cD(residue_cD), + tCcD(tCcD), + residue_tCcD(residue_tCcD), + tCrC(tCrC), + thread_idx(thread_idx) {} +}; + +template +struct Sm90VisitorImplBase { + // Shared memory allocation + using SharedStorage = tuple; + // Host side fusion arguments + using Arguments = tuple; + // Device side fusion params (Kernel-entry API) + using Params = tuple; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { + using Op = cute::remove_cvref_t; + auto ret = Op::to_underlying_arguments(problem_shape, op_args, op_workspace); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return ret; + }, + [] (auto&&... op_params) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(op_params...); } + ); + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { + using Op = cute::remove_cvref_t; + return Op::can_implement(problem_shape, op_args); + }, + [&] (auto&&... implementable) CUTLASS_LAMBDA_FUNC_INLINE { + return (true && ... && implementable); + } + ); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { + using Op = cute::remove_cvref_t; + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + return round_nearest(op_workspace_size, MinWorkspaceAlignment); + }, + [&] (auto&&... op_workspace_size) CUTLASS_LAMBDA_FUNC_INLINE { + return (0 + ... 
+ op_workspace_size); + } + ); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + // Initialize each operation's workspace, stopping at the first error + [&] (auto&& op, auto const& op_args) CUTLASS_LAMBDA_FUNC_INLINE { + if (status != Status::kSuccess) { + return status; + } + + using Op = cute::remove_cvref_t; + status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream, cuda_adapter); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return status; + }, + // Return the final status + [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { return status; } + ); + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops(transform_apply(tuple{}, params, shared_storage, + [] (auto&& op, auto const& op_params, auto&& op_storage) CUTLASS_LAMBDA_FUNC_INLINE { + using Op = cute::remove_cvref_t; + return Op(op_params, op_storage); + }, + [] (auto&&... ops) CUTLASS_LAMBDA_FUNC_INLINE { return cute::make_tuple(ops...); } + )) {} + + // Ops can store kernel persistent variables (e.g. 
descriptors, scalars, wave counters) + tuple ops; +}; + +template +struct Sm90VisitorImpl : Sm90VisitorImplBase { + + using Impl = Sm90VisitorImplBase; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + using Impl::ops; + + // + // Queries for kernel runtime + // + + // Is a specialized warp for producer TMA loads needed + // e.g. Aux tensor loads, broadcasts using TMA bulk copy + // This condition cannot change between work tiles because it is used + // to determine whether the load warp should exit early or not + // e.g. for batched beta this must always be true regardless of current batch idx + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { + return (false || ... || op.is_producer_load_needed()); + } + ); + } + + // Is a producer TMA load specifically for C needed + // If this is true then is_producer_load_needed must also be true + // This condition can change between work tiles because it is only used + // to determine whether the TMA and smem loads for C of a given tile should happen + // e.g. for batched beta this can be false depending on current batch idx + CUTLASS_DEVICE bool + is_C_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) CUTLASS_LAMBDA_FUNC_INLINE { + return (false || ... || op.is_C_load_needed()); + } + ); + } + + // Producer load callbacks factory + // All operations must redefine this, but most can just dispatch to the base impl + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return transform_apply(ops, + [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { + return op.get_producer_load_callbacks(args); + }, + [] (auto&&... 
callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ProducerLoadCallbacksImpl{callbacks_tuple}; + } + ); + } + + // Consumer store callbacks factory + // All operations must redefine this + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return transform_apply(ops, + [&] (auto& op) CUTLASS_LAMBDA_FUNC_INLINE { + return op.template get_consumer_store_callbacks(args); + }, + [] (auto&&... callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ConsumerStoreCallbacksImpl{callbacks_tuple}; + } + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Convenience aliases +using EmptyProducerLoadCallbacks = ProducerLoadCallbacksImpl>; +using EmptyConsumerStoreCallbacks = ConsumerStoreCallbacksImpl>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tree visitor +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sm90TreeVisitor : Sm90VisitorImpl { + + using Impl = Sm90VisitorImpl; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + 
template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(ChildOps); + return cute::detail::tapply(callbacks_tuple, + [&] (auto& child_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) + }, + [&] (auto&&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + make_seq{} // restrict the transform to R-1 child ops, apply is for node op + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_impl = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// DAG visitors +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Most DAG fusions can be represented as a set of output trees with a common input tree +// The common input is first evaluated, then the result is passed as the acc fragment to the output trees +template +struct Sm90SplitTreeVisitor : Sm90VisitorImpl { + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + constexpr int Rm2 = sizeof...(AuxOutTrees); + cute::for_each(make_seq{}, // restrict the sequence to aux 
out trees + [&] (auto I) CUTLASS_LAMBDA_FUNC_INLINE { + get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + ); + + return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_impl = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // deducing the output type for all the nodes is tricky so we just convert them all to a common type + // if multiple compute types are needed then split into multiple subgraphs grouped by type + class ElementCompute, + class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node + class... Ops // in topological order, last op is the output. 
EdgeTuple must match this order +> +struct Sm90TopologicalVisitor : Sm90VisitorImpl { + static_assert(is_static_v); + static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); + static_assert(sizeof...(Ops) > 1); + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(Ops) - 1; + auto frg_compute_tuple = cute::repeat(Array{}); + + return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, + // Visit the first R-1 ops in topological order + [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) CUTLASS_LAMBDA_FUNC_INLINE { + frg_compute = cute::detail::apply(frg_compute_tuple, + // Compute the current op with children inputs + [&] (auto const&... frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { + auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + using ElementOutput = typename decltype(frg_output)::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + + return convert_output(frg_output); + }, + // Get inputs in the sequence given by the children indices of the current op + edge_seq + ); + return frg_compute; // unused + }, + // Visit the last op + [&] (auto const&...ops) CUTLASS_LAMBDA_FUNC_INLINE { + return cute::detail::apply(frg_compute_tuple, + // Compute the last op with children inputs + [&] (auto const&... 
frg_inputs) CUTLASS_LAMBDA_FUNC_INLINE { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + // Get inputs in the sequence given by the children indices of the last op + get(EdgeTuple{}) + ); + }, + // Transform to visit R-1 ops, apply to visit last op + make_seq{} + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_impl = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(cute::move(callbacks_impl)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Base specializations so we can have standard layout params and simple aggregate initializers +namespace detail { + +template +struct Sm90VisitorImplBase { + + // Retain tuple for SharedStorage because empty structs have 1B alignment + // tuples use multiple inheritance, avoids this problem + using SharedStorage = tuple< + typename Op0::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + }; + + struct Params { + typename Op0::Params op_0; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + 
static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1); + } + + template + static size_t + 
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments 
op_2; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = 
reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage, + typename Op3::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + typename Op3::Arguments op_3; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + typename Op3::Params op_3; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, 
Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + size_t op_2_workspace_size = Op2::get_workspace_size(problem_shape, args.op_2); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + uint8_t* op_3_workspace = op_2_workspace + op_2_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace), + Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2) && + Op3::can_implement(problem_shape, args.op_3); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + 
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)), + Op3(params.op_3, get<3>(shared_storage)) + }) {} + + tuple ops; +}; + +} // 
namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd378419567b1680c400ec38746211a577a3c409 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -0,0 +1,763 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Top-K + Softmax reduction across columns +// Performs a reduction of top-K values across N, and finally performs a softmax on them, +// and sets values not in the top-K to 0. +// +// Assumptions: +// 1. CTA_N >= N (single tile across N, the mode which is reduced) +// 2. EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one +// epilogue tile at a time.) +// 3. Top-K value is either 2 or 4. 
+// + +namespace detail { + +// Implementations for add to sorted list and merging sorted lists, +// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). +// Generic implementations may result in greater register use and branching, +// and should be avoided. +// Fast paths for Top-2 and Top-4 are written in inline PTX directly. + +CUTLASS_DEVICE +Array top_2_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx, %3, %4;\n" + " setp.gtu.f32 p, %2, %4;\n" + " selp.f32 %1, mx, %2, p;\n" + " selp.f32 %0, %2, %4, p;\n" + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_2_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .v2 .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx.x, %3, %4;\n" // max(a1, b0) + " max.f32 mx.y, %2, %5;\n" // max(a0, b1) + " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 + " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) + " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 + "}\n" : "=f"(out[0]), "=f"(out[1]) : + "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" // max(a3, b) + " .reg .pred p0;\n" // a0 > b + " .reg .pred p1;\n" // a1 > b + " .reg .pred p2;\n" // a2 > b + " max.f32 mx, %7, %8;\n" // max(a3, b) + " setp.gtu.f32 p0, %4, %8;\n" // a0 > b + " setp.gtu.f32 p1, %5, %8;\n" // a1 > b + " setp.gtu.f32 p2, %6, %8;\n" // a2 > b + " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 + " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b + " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 + " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b + " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 + " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? 
a0 : b + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mxa0b1;\n" // max(a0, b1) + " .reg .f32 mxa1b0;\n" // max(a1, b0) + + " .reg .f32 mxa2b0;\n" // max(a2, b0) + " .reg .f32 mxa1b1;\n" // max(a1, b1) + " .reg .f32 mxa0b2;\n" // max(a1, b1) + + " .reg .f32 mxa1b2;\n" // max(a1, b2) + " .reg .f32 mxa2b1;\n" // max(a2, b1) + " max.f32 mxa1b2, %5, %10;\n" + " max.f32 mxa2b1, %6, %9;\n" + + " .reg .f32 mxa3b0;\n" // max(a1, b2) + " .reg .f32 mxa0b3;\n" // max(a2, b1) + " max.f32 mxa3b0, %7, %8;\n" + " max.f32 mxa0b3, %4, %11;\n" + + " .reg .pred pa0b0;\n" // a0 > b0 + " .reg .pred pa1b0;\n" // a1 > b0 + " .reg .pred pa2b0;\n" // a2 > b0 + " .reg .pred pa0b1;\n" // a0 > b1 + " .reg .pred pa1b1;\n" // a1 > b1 + " .reg .pred pa0b2;\n" // a0 > b2 + " .reg .pred pb2a0;\n" // b1 > a0 + " .reg .pred pb1a0;\n" // b1 > a0 + + " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 + " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 + " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 + " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 + " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 + " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 + + " not.pred pb2a0, pa0b2;\n" + " not.pred pb1a0, pa0b1;\n" + + " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) + " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) + + " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) + " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) + " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) + + // a0 + " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 + + // a1 + " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? 
max(a1, b0) : max(a0, b1) + + // a2 + " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case + " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 + " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 + + // a3 + " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases + " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case + " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 + " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), + "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); + return out; +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce_scalar(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce_scalar(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] < b) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b; + break; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] and b[0] are the largest elements in a[] and b[].) 
+template +CUTLASS_DEVICE +void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + int j = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] < b[j]) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b[j]; + ++j; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +Element topk_logsumexp(cutlass::Array a) { + // Do one less `exp`, because we know what its result will be. + // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. + // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) + // Compute m + log(1 + sum_{i != m}(x_i - x_m)) + Element sum = Element(1.0); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N; ++i) { + sum += fast_exp(a[i] - a[0]); + } + return a[0] + fast_log(sum); +} + +CUTLASS_DEVICE +float fast_masked_softmax(float value, float minimum, float logsumexp) { + float new_value; + asm volatile( + "{\n" + " .reg .pred p0;\n" + // value >= minimum + " setp.geu.f32 p0, %1, %2;\n" + + " .reg .f32 x_lse;\n" + " .reg .f32 %%f<11>;\n" + " .reg .b32 %%r<3>;\n" + + // x_lse = value - minimum + " sub.rn.f32 x_lse, %1, %3;\n" + + // exp(x_lse) + // The following is derived from a ptx dump of expf. + // exp requires a base conversion from exp2. 
+ " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" + " cvt.sat.f32.f32 %%f2, %%f1;\n" + " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" + " add.f32 %%f4, %%f3, 0fCB40007F;\n" + " neg.f32 %%f5, %%f4;\n" + " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" + " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" + " mov.b32 %%r1, %%f3;\n" + " shl.b32 %%r2, %%r1, 23;\n" + " mov.b32 %%f8, %%r2;\n" + " ex2.approx.ftz.f32 %%f9, %%f7;\n" + " mul.f32 %%f10, %%f9, %%f8;\n" + + // Mask or softmax + " selp.f32 %0, %%f10, 0f00000000, p0;\n" + "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); + return new_value; +} + +template +CUTLASS_DEVICE +Element masked_softmax(Element value, Element minimum, Element logsumexp) { + if constexpr (is_same_v) { + // Inline PTX implementation + // Significantly reduces register requirements + return fast_masked_softmax(value, minimum, logsumexp); + } + else { + return value < minimum ? Element(0.0) : fast_exp(value - logsumexp); + } +} + +} // namespace detail + +template < + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + int Alignment = 128 / sizeof_bits_v, + bool UseButterflyReduce = true +> +struct Sm90TopKSoftmaxColReduction { +private: + static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); + static_assert(TopK == 2 || TopK == 4, + "Fused Top-K + Softmax reduction only allows K=2 and K=4, because those cases have been performance-optimized. Other values of K can be enabled by removing this assertion, but they may come with serious performance implications." + ); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + // Reduction tensors + // We have two tensors for this EVT node: a reduction tensor and a tensor holding + // final reduction values (tCrSoftmax). 
The reason for this is that Top-K and Softmax + // require different reductions, but those luckily overlap. Top-K obviously needs at least + // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log + // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the + // maximum of all x_i elements. Since safe softmax for any element x_i is computed as + // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) + // we can track logsumexp instead of tracking two variables (sum of exps and the max). + // In addition, subtracting logsumexp from any element and taking its exp is equivalent to + // computing its softmax. + // + // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the + // way at all, because any element not in the top-K is going to be masked out and set to 0. + // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and + // keep it, and the smallest element in the top-K for masking out non-top-K elements. + // + // This means that our final reduction result will always be 2 elements, regardless of the value + // of K: minimum of top-K, and logsumexp. + // + // For each reduction tensor, we define a new struct for readability. + + struct ReductionResult { + ElementCompute min_; + ElementCompute logsumexp_; + + CUTLASS_DEVICE + ReductionResult() { } + + CUTLASS_DEVICE + ReductionResult(ElementCompute min, ElementCompute logsumexp): + logsumexp_(logsumexp), min_(min) { } + + // Warp shuffle broadcast + CUTLASS_DEVICE + void shuffle_up_sync(uint32_t delta, int lane_id) { + static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); + uint64_t r = reinterpret_cast(*this); + r = __shfl_up_sync(0xFFFFFFFF, r, delta); + *this = (lane_id - static_cast(delta) >= 0) ? 
reinterpret_cast(r) : *this; + } + }; + + struct TopKResult { + Array top_k_; + + CUTLASS_DEVICE + TopKResult() { + top_k_.fill(-cutlass::platform::numeric_limits::infinity()); + } + + // This is where we do the "final" reduction, where we compute + // the logsumexp for softmax, keep the smallest value in top-K, + // and discard the rest. + CUTLASS_DEVICE + ReductionResult reduce_final() const { + return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); + } + + // Butterfly reduction + CUTLASS_DEVICE + void shuffle_xor_sync(int laneMask) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); + top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + + // Warp shuffle reduction + CUTLASS_DEVICE + void shuffle_down_sync(uint32_t delta) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + 
static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); + top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + }; + +public: + struct SharedStorage { }; + + struct Arguments { }; + + struct Params { }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. 
+ auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N && N <= epi_N && N >= TopK; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); + if (elem_less(thread_crd, residue_tCcCol)) { + TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); + detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); + } + } + + return frg_input; + } + + 
template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, + // in order to reduce along modes in the `R2S` sublayout that correspond to N. + // This means we should modify and warp-reduce them according to their co-domain instead of + // their domain. Therefore we keep a filtered view of both and use them as necessary. + auto tCrTopK_f = filter(tCrTopK); + auto tCrSoftmax_f = filter(tCrSoftmax); + + // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the + // last element of Top-K, use the latter to mask the visited results, and the former + // to apply softmax. + // + // This gives us two options: reduce the Top-K with warp shuffles, have the reduced + // lanes compute logsumexp and pair it with the last Top-K element, and broadcast + // the result back using warp shuffles. + // + // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes + // compute their own logsumexp and skip the broadcast. + if constexpr (UseButterflyReduce) { + // + // 1. Butterfly reduction + // + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_xor_sync(j); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + else { + // + // 1. 
Warp shuffle reduction + // + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + bool is_reduced_lane = get<1>(lane_mn) == 0; + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + + // + // 3. Broadcast reduced values to all participants + // + CUTLASS_PRAGMA_UNROLL + for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); + } + } + } + + // + // 4. Re-visit and apply top-K and softmax + // + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { + auto& visit_frag = visit_results(epi_v); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + visit_frag[i] = detail::masked_softmax( + visit_frag[i], + tCrSoftmax(epi_v * FragmentSize + i).min_, + tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ + ); + } + } + + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // Reset reduced top-K values for next tile + // This must be done because we only assume a single epilogue tile across N, + // but not M. + fill(tCrTopK, TopKResult()); + } + + CUTLASS_DEVICE void + end() { } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + auto mn_shape = shape(typename decltype(args.tiled_copy)::Tiler_MN{}); + if constexpr (ReferenceSrc) { return right_inverse(args.tiled_copy.get_layoutS_TV()).with_shape(mn_shape); } + else { return right_inverse(args.tiled_copy.get_layoutD_TV()).with_shape(mn_shape); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + + // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. + static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); + + // Reduction layout + // We're assuming all elements in a row (over which we're performing the reduction) are + // visited in the same corresponding epilogue tile, and this is what allows us to apply the + // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. + // + // This presents a challenge, because the layout of the accumulated results is typically in + // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). 
+ // This means that we still need to reduce this tensor along N. + // + // The solution is simple: we need to flatten the layout, identify modes that correspond to + // N and set their strides to 0, in order to map fragment indices corresponding to the same + // row back to the same element in the tensor. + // + // This requires some extra layout manipulation, which is as follows. + + // Create new accumulator layout with column broadcast + auto [M, N, K] = args.tile_shape_mnk; + auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); + auto gColReduce = make_tensor( + make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) + auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) + thr_mma.partition_C(gColReduce).layout()); + + // Tile the new accumulator tensor according to R2S + ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); + Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) + auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) + + // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // reduction tensor layout. 
+ auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) + + Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + fill(tCrTopK, TopKResult()); + + auto args_tuple = make_tuple( + cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, + lane_layout_MN, lane_mn, + args.residue_cD, args.residue_tCcD); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h new file mode 100644 index 0000000000000000000000000000000000000000..8412b5037b3aacbca4d28b80b99839acb368d5df --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/include/third-party/cutlass/include/cutlass/epilogue/thread/activation.h @@ -0,0 +1,914 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This extends the contents of cutlass/functional.h with frequently used activation functions. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/constants.h" +#include "cutlass/complex.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// If kIsHeavy is a member, use it. Otherwise, assume that it's false. +template +struct kIsHeavy_member_or_false { + static constexpr bool value = false; +}; +template +struct kIsHeavy_member_or_false::type> { + static constexpr bool value = Op::kIsHeavy; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Identity operator +template +struct Identity { + static const bool kIsHeavy = false; + + CUTLASS_HOST_DEVICE + T operator()(T value) const { + return value; + } +}; + +template +struct Identity > { + CUTLASS_HOST_DEVICE + Array operator()(Array value) const { + return value; + } +}; + +/// Scale operator +template +struct Scale { + struct Arguments { + using scale_type = T; + T scale = T(1); + }; + + CUTLASS_HOST_DEVICE + T operator()(T value, T scale) const { + multiplies mul; + return mul(scale, value); + } + + CUTLASS_HOST_DEVICE + T operator()(T value, Arguments args = Arguments()) const { + return this->operator()(value, args.scale); + } +}; + +template +struct Scale> { + using Arguments = typename Scale::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array values, T scale) const { + multiplies> mul; + return mul(scale, values); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array values, Arguments args = Arguments()) const { + return this->operator()(values, args.scale); + } +}; + +/// Specialization to compose other activations with a defined unary 
operator +/// e.g. Scale> +template